diff --git a/Cargo.lock b/Cargo.lock index 78b6fb7dc6..1dba57c3e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1698,6 +1698,26 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "governor" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" +dependencies = [ + "cfg-if", + "dashmap 5.5.3", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.8.5", + "smallvec", + "spinning_top", +] + [[package]] name = "h2" version = "0.3.27" @@ -2568,10 +2588,12 @@ dependencies = [ "email_address", "futures", "glob", + "governor", "headers", "html5ever", "html5gum", "http 1.3.1", + "humantime-serde", "hyper 1.7.0", "ignore", "ip_network", @@ -2740,6 +2762,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -2750,6 +2778,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "normalize-line-endings" version = "0.3.0" @@ -3327,6 +3361,21 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi 0.11.1+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.11.8" @@ -3456,6 +3505,15 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags 2.9.2", +] + [[package]] name = "rayon" version = "1.11.0" @@ -4170,6 +4228,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/README.md b/README.md index 198a9a705b..600b80287c 100644 --- a/README.md +++ b/README.md @@ -373,6 +373,9 @@ Options: Do not show progress bar. This is recommended for non-interactive shells (e.g. for continuous integration) + --host-stats + Show per-host statistics at the end of the run + --extensions Test the specified file extensions for URIs when checking files locally. @@ -385,7 +388,12 @@ Options: --default-extension Default file extension to treat files without extensions as having. - This is useful for files without extensions or with unknown extensions. The extension will be used to determine the file type for processing. Examples: --default-extension md, --default-extension html + This is useful for files without extensions or with unknown extensions. + The extension will be used to determine the file type for processing. + + Examples: + --default-extension md + --default-extension html --cache Use request cache stored on disk at `.lycheecache` @@ -447,6 +455,28 @@ Options: [default: 128] + --host-concurrency + Default maximum concurrent requests per host (default: 10) + + This limits how many requests can be sent simultaneously to the same + host (domain/subdomain). This helps prevent overwhelming servers and + getting rate-limited. Each host is handled independently. + + Examples: + --host-concurrency 5 # Conservative for slow APIs + --host-concurrency 20 # Aggressive for fast APIs + + --request-interval + Minimum interval between requests to the same host (default: 100ms) + + Sets a baseline delay between consecutive requests to prevent + hammering servers. The adaptive algorithm may increase this based + on server responses (rate limits, errors). + + Examples: + --request-interval 50ms # Fast for robust APIs + --request-interval 1s # Conservative for rate-limited APIs + -T, --threads Number of threads to utilize. Defaults to number of cores available to the system diff --git a/lychee-bin/src/client.rs b/lychee-bin/src/client.rs index 4c99f6fe7c..665789e2e7 100644 --- a/lychee-bin/src/client.rs +++ b/lychee-bin/src/client.rs @@ -2,7 +2,10 @@ use crate::options::{Config, HeaderMapExt}; use crate::parse::{parse_duration_secs, parse_remaps}; use anyhow::{Context, Result}; use http::{HeaderMap, StatusCode}; -use lychee_lib::{Client, ClientBuilder}; +use lychee_lib::{ + Client, ClientBuilder, + ratelimit::{HostPool, RateLimitConfig}, +}; use regex::RegexSet; use reqwest_cookie_store::CookieStoreMutex; use std::sync::Arc; @@ -28,6 +31,35 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc>) - let headers = HeaderMap::from_header_pairs(&cfg.header)?; + // Create combined headers for HostPool (includes User-Agent + custom headers) + let mut combined_headers = headers.clone(); + combined_headers.insert( + http::header::USER_AGENT, + cfg.user_agent + .parse() + .context("Invalid User-Agent header")?, + ); + + // Create HostPool for rate limiting - always enabled for HTTP requests + let rate_limit_config = + RateLimitConfig::from_options(cfg.host_concurrency, cfg.request_interval); + let cache_max_age = if cfg.cache { 3600 } else { 0 }; // 1 hour if caching enabled, disabled otherwise + + let mut host_pool = HostPool::new( + rate_limit_config, + cfg.hosts.clone(), + cfg.max_concurrency, + cache_max_age, + combined_headers, + cfg.max_redirects, + Some(timeout), + cfg.insecure, + ); + + if let Some(cookie_jar) = cookie_jar { + host_pool = host_pool.with_cookie_jar(cookie_jar.clone()); + } + ClientBuilder::builder() .remaps(remaps) .base(cfg.base_url.clone()) @@ -55,6 +87,7 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc>) - .include_fragments(cfg.include_fragments) .fallback_extensions(cfg.fallback_extensions.clone()) .index_files(cfg.index_files.clone()) + .host_pool(Some(host_pool)) .build() .client() .context("Failed to create request client") diff --git a/lychee-bin/src/commands/check.rs b/lychee-bin/src/commands/check.rs index 0a266c6ffc..df93031589 100644 --- a/lychee-bin/src/commands/check.rs +++ b/lychee-bin/src/commands/check.rs @@ -27,7 +27,7 @@ use super::CommandParams; pub(crate) async fn check( params: CommandParams, -) -> Result<(ResponseStats, Arc, ExitCode)> +) -> Result<(ResponseStats, Arc, ExitCode, Client)> where S: futures::Stream>, { @@ -47,6 +47,7 @@ where let cache_ref = params.cache.clone(); let client = params.client; + let client_for_return = client.clone(); let cache = params.cache; let cache_exclude_status = params.cfg.cache_exclude_status.into_set(); let accept = params.cfg.accept.into(); @@ -120,7 +121,7 @@ where } else { ExitCode::LinkCheckFailure }; - Ok((stats, cache_ref, code)) + Ok((stats, cache_ref, code, client_for_return)) } async fn suggest_archived_links( @@ -287,6 +288,8 @@ async fn handle( accept: HashSet, ) -> Response { let uri = request.uri.clone(); + + // First check the persistent disk-based cache if let Some(v) = cache.get(&uri) { // Found a cached request // Overwrite cache status in case the URI is excluded in the @@ -300,18 +303,27 @@ async fn handle( // code. Status::from_cache_status(v.value().status, &accept) }; + + // Track cache hit in the per-host stats (only for network URIs) + if !uri.is_file() { + if let Err(e) = client.record_cache_hit(&uri) { + log::debug!("Failed to record cache hit for {uri}: {e}"); + } + } + return Response::new(uri.clone(), status, request.source); } - // Request was not cached; run a normal check + // Cache miss - track it and run a normal check (only for network URIs) + if !uri.is_file() { + if let Err(e) = client.record_cache_miss(&uri) { + log::debug!("Failed to record cache miss for {uri}: {e}"); + } + } + let response = check_url(client, request).await; - // - Never cache filesystem access as it is fast already so caching has no - // benefit. - // - Skip caching unsupported URLs as they might be supported in a - // future run. - // - Skip caching excluded links; they might not be excluded in the next run. - // - Skip caching links for which the status code has been explicitly excluded from the cache. + // Apply the same caching rules as before let status = response.status(); if ignore_cache(&uri, status, &cache_exclude_status) { return response; diff --git a/lychee-bin/src/formatters/host_stats/compact.rs b/lychee-bin/src/formatters/host_stats/compact.rs new file mode 100644 index 0000000000..121230e259 --- /dev/null +++ b/lychee-bin/src/formatters/host_stats/compact.rs @@ -0,0 +1,81 @@ +use anyhow::Result; +use std::{ + collections::HashMap, + fmt::{self, Display}, +}; + +use crate::formatters::color::{DIM, NORMAL, color}; +use lychee_lib::ratelimit::HostStats; + +use super::HostStatsFormatter; + +struct CompactHostStats { + host_stats: HashMap, +} + +impl Display for CompactHostStats { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.host_stats.is_empty() { + return Ok(()); + } + + writeln!(f)?; + writeln!(f, "šŸ“Š Per-host Statistics")?; + + let separator = "─".repeat(60); + color!(f, DIM, "{}", separator)?; + writeln!(f)?; + + let sorted_hosts = super::sort_host_stats(&self.host_stats); + + // Calculate optimal hostname width based on longest hostname + let max_hostname_len = sorted_hosts + .iter() + .map(|(hostname, _)| hostname.len()) + .max() + .unwrap_or(0); + let hostname_width = (max_hostname_len + 2).max(10); // At least 10 chars with padding + + for (hostname, stats) in sorted_hosts { + let median_time = stats + .median_request_time() + .map_or_else(|| "N/A".to_string(), |d| format!("{:.0}ms", d.as_millis())); + + let cache_hit_rate = stats.cache_hit_rate() * 100.0; + + color!( + f, + NORMAL, + "{:6} reqs │ {:>6.1}% success │ {:>8} median │ {:>6.1}% cached", + hostname, + stats.total_requests, + stats.success_rate() * 100.0, + median_time, + cache_hit_rate, + width = hostname_width + )?; + writeln!(f)?; + } + + Ok(()) + } +} + +pub(crate) struct Compact; + +impl Compact { + pub(crate) const fn new() -> Self { + Self + } +} + +impl HostStatsFormatter for Compact { + fn format(&self, host_stats: HashMap) -> Result> { + if host_stats.is_empty() { + return Ok(None); + } + + let compact = CompactHostStats { host_stats }; + Ok(Some(compact.to_string())) + } +} diff --git a/lychee-bin/src/formatters/host_stats/detailed.rs b/lychee-bin/src/formatters/host_stats/detailed.rs new file mode 100644 index 0000000000..01bfd42bc8 --- /dev/null +++ b/lychee-bin/src/formatters/host_stats/detailed.rs @@ -0,0 +1,90 @@ +use anyhow::Result; +use std::{ + collections::HashMap, + fmt::{self, Display}, +}; + +use lychee_lib::ratelimit::HostStats; + +use super::HostStatsFormatter; + +struct DetailedHostStats { + host_stats: HashMap, +} + +impl Display for DetailedHostStats { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.host_stats.is_empty() { + return Ok(()); + } + + writeln!(f, "\nšŸ“Š Per-host Statistics")?; + writeln!(f, "---------------------")?; + + let sorted_hosts = super::sort_host_stats(&self.host_stats); + + for (hostname, stats) in sorted_hosts { + writeln!(f, "\nHost: {hostname}")?; + writeln!(f, " Total requests: {}", stats.total_requests)?; + writeln!( + f, + " Successful: {} ({:.1}%)", + stats.successful_requests, + stats.success_rate() * 100.0 + )?; + + if stats.rate_limited > 0 { + writeln!( + f, + " Rate limited: {} (429 Too Many Requests)", + stats.rate_limited + )?; + } + if stats.client_errors > 0 { + writeln!(f, " Client errors (4xx): {}", stats.client_errors)?; + } + if stats.server_errors > 0 { + writeln!(f, " Server errors (5xx): {}", stats.server_errors)?; + } + + if let Some(median_time) = stats.median_request_time() { + writeln!( + f, + " Median response time: {:.0}ms", + median_time.as_millis() + )?; + } + + let cache_hit_rate = stats.cache_hit_rate(); + if cache_hit_rate > 0.0 { + writeln!(f, " Cache hit rate: {:.1}%", cache_hit_rate * 100.0)?; + writeln!( + f, + " Cache hits: {}, misses: {}", + stats.cache_hits, stats.cache_misses + )?; + } + } + + Ok(()) + } +} + +pub(crate) struct Detailed; + +impl Detailed { + pub(crate) const fn new() -> Self { + Self + } +} + +impl HostStatsFormatter for Detailed { + fn format(&self, host_stats: HashMap) -> Result> { + if host_stats.is_empty() { + return Ok(None); + } + + let detailed = DetailedHostStats { host_stats }; + Ok(Some(detailed.to_string())) + } +} diff --git a/lychee-bin/src/formatters/host_stats/json.rs b/lychee-bin/src/formatters/host_stats/json.rs new file mode 100644 index 0000000000..24f7fe0d2e --- /dev/null +++ b/lychee-bin/src/formatters/host_stats/json.rs @@ -0,0 +1,57 @@ +use anyhow::{Context, Result}; +use serde_json::json; +use std::collections::HashMap; + +use super::HostStatsFormatter; +use lychee_lib::ratelimit::HostStats; + +pub(crate) struct Json; + +impl Json { + pub(crate) const fn new() -> Self { + Self {} + } +} + +impl HostStatsFormatter for Json { + /// Format host stats as JSON object + fn format(&self, host_stats: HashMap) -> Result> { + if host_stats.is_empty() { + return Ok(None); + } + + // Convert HostStats to a more JSON-friendly format + let json_stats: HashMap = host_stats + .into_iter() + .map(|(hostname, stats)| { + let json_value = json!({ + "total_requests": stats.total_requests, + "successful_requests": stats.successful_requests, + "success_rate": stats.success_rate(), + "rate_limited": stats.rate_limited, + "client_errors": stats.client_errors, + "server_errors": stats.server_errors, + "median_request_time_ms": stats.median_request_time() + .map(|d| { + #[allow(clippy::cast_possible_truncation)] + let millis = d.as_millis() as u64; + millis + }), + "cache_hits": stats.cache_hits, + "cache_misses": stats.cache_misses, + "cache_hit_rate": stats.cache_hit_rate(), + "status_codes": stats.status_codes + }); + (hostname, json_value) + }) + .collect(); + + let output = json!({ + "host_statistics": json_stats + }); + + serde_json::to_string_pretty(&output) + .map(Some) + .context("Cannot format host stats as JSON") + } +} diff --git a/lychee-bin/src/formatters/host_stats/markdown.rs b/lychee-bin/src/formatters/host_stats/markdown.rs new file mode 100644 index 0000000000..8980066107 --- /dev/null +++ b/lychee-bin/src/formatters/host_stats/markdown.rs @@ -0,0 +1,92 @@ +use std::{ + collections::HashMap, + fmt::{self, Display}, +}; + +use super::HostStatsFormatter; +use anyhow::Result; +use lychee_lib::ratelimit::HostStats; +use tabled::{ + Table, Tabled, + settings::{Alignment, Modify, Style, object::Segment}, +}; + +#[derive(Tabled)] +struct HostStatsTableEntry { + #[tabled(rename = "Host")] + host: String, + #[tabled(rename = "Requests")] + requests: u64, + #[tabled(rename = "Success Rate")] + success_rate: String, + #[tabled(rename = "Median Time")] + median_time: String, + #[tabled(rename = "Cache Hit Rate")] + cache_hit_rate: String, +} + +fn host_stats_table(host_stats: &HashMap) -> String { + let sorted_hosts = super::sort_host_stats(host_stats); + + let entries: Vec = sorted_hosts + .into_iter() + .map(|(hostname, stats)| { + let median_time = stats + .median_request_time() + .map_or_else(|| "N/A".to_string(), |d| format!("{:.0}ms", d.as_millis())); + + HostStatsTableEntry { + host: hostname.clone(), + requests: stats.total_requests, + success_rate: format!("{:.1}%", stats.success_rate() * 100.0), + median_time, + cache_hit_rate: format!("{:.1}%", stats.cache_hit_rate() * 100.0), + } + }) + .collect(); + + if entries.is_empty() { + return String::new(); + } + + let style = Style::markdown(); + Table::new(entries) + .with(Modify::new(Segment::all()).with(Alignment::left())) + .with(style) + .to_string() +} + +struct MarkdownHostStats(HashMap); + +impl Display for MarkdownHostStats { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.0.is_empty() { + return Ok(()); + } + + writeln!(f, "\n## Per-host Statistics")?; + writeln!(f)?; + writeln!(f, "{}", host_stats_table(&self.0))?; + + Ok(()) + } +} + +pub(crate) struct Markdown; + +impl Markdown { + pub(crate) const fn new() -> Self { + Self {} + } +} + +impl HostStatsFormatter for Markdown { + fn format(&self, host_stats: HashMap) -> Result> { + if host_stats.is_empty() { + return Ok(None); + } + + let markdown = MarkdownHostStats(host_stats); + Ok(Some(markdown.to_string())) + } +} diff --git a/lychee-bin/src/formatters/host_stats/mod.rs b/lychee-bin/src/formatters/host_stats/mod.rs new file mode 100644 index 0000000000..8c312bfdd5 --- /dev/null +++ b/lychee-bin/src/formatters/host_stats/mod.rs @@ -0,0 +1,28 @@ +mod compact; +mod detailed; +mod json; +mod markdown; + +pub(crate) use compact::Compact; +pub(crate) use detailed::Detailed; +pub(crate) use json::Json; +pub(crate) use markdown::Markdown; + +use anyhow::Result; +use lychee_lib::ratelimit::HostStats; +use std::collections::HashMap; + +/// Trait for formatting per-host statistics in different output formats +pub(crate) trait HostStatsFormatter { + /// Format the host statistics and return them as a string + fn format(&self, host_stats: HashMap) -> Result>; +} + +/// Sort host statistics by request count (descending order) +/// This matches the display order we want in the output +fn sort_host_stats(host_stats: &HashMap) -> Vec<(&String, &HostStats)> { + let mut sorted_hosts: Vec<_> = host_stats.iter().collect(); + // Sort by total requests (descending) + sorted_hosts.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_requests)); + sorted_hosts +} diff --git a/lychee-bin/src/formatters/mod.rs b/lychee-bin/src/formatters/mod.rs index ddb82cf78c..2e2dfad51d 100644 --- a/lychee-bin/src/formatters/mod.rs +++ b/lychee-bin/src/formatters/mod.rs @@ -1,11 +1,12 @@ pub(crate) mod color; pub(crate) mod duration; +pub(crate) mod host_stats; pub(crate) mod log; pub(crate) mod response; pub(crate) mod stats; pub(crate) mod suggestion; -use self::{response::ResponseFormatter, stats::StatsFormatter}; +use self::{host_stats::HostStatsFormatter, response::ResponseFormatter, stats::StatsFormatter}; use crate::options::{OutputMode, StatsFormat}; use supports_color::Stream; @@ -29,6 +30,19 @@ pub(crate) fn get_stats_formatter( } } +/// Create a host stats formatter based on the given format and mode options +pub(crate) fn get_host_stats_formatter( + format: &StatsFormat, + _mode: &OutputMode, +) -> Box { + match format { + StatsFormat::Compact | StatsFormat::Raw => Box::new(host_stats::Compact::new()), // Use compact for raw + StatsFormat::Detailed => Box::new(host_stats::Detailed::new()), + StatsFormat::Json => Box::new(host_stats::Json::new()), + StatsFormat::Markdown => Box::new(host_stats::Markdown::new()), + } +} + /// Create a response formatter based on the given format option pub(crate) fn get_response_formatter(mode: &OutputMode) -> Box { // Checks if color is supported in current environment or NO_COLOR is set (https://no-color.org) diff --git a/lychee-bin/src/host_stats.rs b/lychee-bin/src/host_stats.rs new file mode 100644 index 0000000000..981d6eee8e --- /dev/null +++ b/lychee-bin/src/host_stats.rs @@ -0,0 +1,29 @@ +use anyhow::{Context, Result}; + +use crate::{formatters::get_host_stats_formatter, options::Config}; + +/// Display per-host statistics if requested +pub(crate) fn display_per_host_statistics( + client: &lychee_lib::Client, + config: &Config, +) -> Result<()> { + if !config.host_stats { + return Ok(()); + } + + let host_stats = client.host_stats(); + let host_stats_formatter = get_host_stats_formatter(&config.format, &config.mode); + + if let Some(formatted_host_stats) = host_stats_formatter.format(host_stats)? { + if let Some(output) = &config.output { + // For file output, append to the existing output + let mut file_content = std::fs::read_to_string(output).unwrap_or_default(); + file_content.push_str(&formatted_host_stats); + std::fs::write(output, file_content) + .context("Cannot write host stats to output file")?; + } else { + print!("{formatted_host_stats}"); + } + } + Ok(()) +} diff --git a/lychee-bin/src/main.rs b/lychee-bin/src/main.rs index b915652ad2..7aadef395f 100644 --- a/lychee-bin/src/main.rs +++ b/lychee-bin/src/main.rs @@ -86,6 +86,7 @@ mod client; mod commands; mod files_from; mod formatters; +mod host_stats; mod options; mod parse; mod stats; @@ -96,6 +97,7 @@ use crate::formatters::duration::Duration; use crate::{ cache::{Cache, StoreExt}, formatters::stats::StatsFormatter, + host_stats::display_per_host_statistics, options::{Config, LYCHEE_CACHE_FILE, LYCHEE_IGNORE_FILE, LycheeOptions}, }; @@ -378,7 +380,7 @@ async fn run(opts: &LycheeOptions) -> Result { let exit_code = if opts.config.dump { commands::dump(params).await? } else { - let (stats, cache, exit_code) = commands::check(params).await?; + let (stats, cache, exit_code, client) = commands::check(params).await?; let github_issues = stats .error_map @@ -406,6 +408,9 @@ async fn run(opts: &LycheeOptions) -> Result { } } + // Display per-host statistics if requested + display_per_host_statistics(&client, &opts.config)?; + if github_issues && opts.config.github_token.is_none() { warn!( "There were issues with GitHub URLs. You could try setting a GitHub token and running lychee again.", diff --git a/lychee-bin/src/options.rs b/lychee-bin/src/options.rs index 4a4a019a26..5d544442da 100644 --- a/lychee-bin/src/options.rs +++ b/lychee-bin/src/options.rs @@ -13,6 +13,7 @@ use lychee_lib::{ Base, BasicAuthSelector, DEFAULT_MAX_REDIRECTS, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_WAIT_TIME_SECS, DEFAULT_TIMEOUT_SECS, DEFAULT_USER_AGENT, FileExtensions, FileType, Input, StatusCodeExcluder, StatusCodeSelector, archive::Archive, + ratelimit::HostConfig, }; use reqwest::tls; use secrecy::SecretString; @@ -416,6 +417,11 @@ pub(crate) struct Config { #[serde(default)] pub(crate) no_progress: bool, + /// Show per-host statistics at the end of the run + #[arg(long)] + #[serde(default)] + pub(crate) host_stats: bool, + /// A list of file extensions. Files not matching the specified extensions are skipped. /// /// E.g. a user can specify `--extensions html,htm,php,asp,aspx,jsp,cgi` @@ -439,8 +445,11 @@ specify both extensions explicitly." /// /// This is useful for files without extensions or with unknown extensions. /// The extension will be used to determine the file type for processing. - /// Examples: --default-extension md, --default-extension html - #[arg(long, value_name = "EXTENSION")] + /// + /// Examples: + /// --default-extension md + /// --default-extension html + #[arg(long, value_name = "EXTENSION", verbatim_doc_comment)] #[serde(default)] pub(crate) default_extension: Option, @@ -525,6 +534,32 @@ with a status code of 429, 500 and 501." #[serde(default = "max_concurrency")] pub(crate) max_concurrency: usize, + /// Default maximum concurrent requests per host (default: 10) + /// + /// This limits how many requests can be sent simultaneously to the same + /// host (domain/subdomain). This helps prevent overwhelming servers and + /// getting rate-limited. Each host is handled independently. + /// + /// Examples: + /// --host-concurrency 5 # Conservative for slow APIs + /// --host-concurrency 20 # Aggressive for fast APIs + #[arg(long = "host-concurrency", verbatim_doc_comment)] + #[serde(default)] + pub(crate) host_concurrency: Option, + + /// Minimum interval between requests to the same host (default: 100ms) + /// + /// Sets a baseline delay between consecutive requests to prevent + /// hammering servers. The adaptive algorithm may increase this based + /// on server responses (rate limits, errors). + /// + /// Examples: + /// --request-interval 50ms # Fast for robust APIs + /// --request-interval 1s # Conservative for rate-limited APIs + #[arg(long = "request-interval", value_parser = humantime::parse_duration, verbatim_doc_comment)] + #[serde(default)] + pub(crate) request_interval: Option, + /// Number of threads to utilize. /// Defaults to number of cores available to the system #[arg(short = 'T', long)] @@ -846,6 +881,11 @@ and existing cookies will be updated." #[arg(long)] #[serde(default)] pub(crate) include_wikilinks: bool, + + /// Host-specific configurations from config file + #[arg(skip)] + #[serde(default)] + pub(crate) hosts: HashMap, } impl Config { @@ -882,6 +922,11 @@ impl Config { self.github_token = toml.github_token; } + // Hosts configuration is only available in TOML for now (not in the CLI) + // That's because it's a bit complex to specify on the command line and + // we didn't come up with a good syntax for it yet. + self.hosts = toml.hosts; + // NOTE: if you see an error within this macro call, check to make sure that // that the fields provided to fold_in! match all the fields of the Config struct. fold_in! { @@ -892,6 +937,7 @@ impl Config { // Keys which are handled outside of fold_in ..header, ..github_token, + ..hosts, // Keys with defaults to assign accept: StatusCodeSelector::default(), @@ -903,6 +949,8 @@ impl Config { cache_exclude_status: StatusCodeExcluder::default(), cookie_jar: None, default_extension: None, + host_concurrency: None, + request_interval: None, dump: false, dump_inputs: false, exclude: Vec::::new(), @@ -917,6 +965,7 @@ impl Config { format: StatsFormat::default(), glob_ignore_case: false, hidden: false, + host_stats: false, include: Vec::::new(), include_fragments: false, include_mail: false, diff --git a/lychee-lib/Cargo.toml b/lychee-lib/Cargo.toml index 04f1c2ce8e..57958076fd 100644 --- a/lychee-lib/Cargo.toml +++ b/lychee-lib/Cargo.toml @@ -22,8 +22,10 @@ dashmap = { version = "6.1.0", features = ["serde"] } email_address = "0.2.9" futures = "0.3.31" glob = "0.3.3" +governor = "0.6.3" headers = "0.4.1" html5ever = "0.35.0" +humantime-serde = "1.1.1" html5gum = "0.8.0" http = "1.3.1" hyper = "1.6.0" diff --git a/lychee-lib/src/checker/website.rs b/lychee-lib/src/checker/website.rs index 2bccc86f05..578de19ac5 100644 --- a/lychee-lib/src/checker/website.rs +++ b/lychee-lib/src/checker/website.rs @@ -2,6 +2,7 @@ use crate::{ BasicAuthCredentials, ErrorKind, FileType, Status, Uri, chain::{Chain, ChainResult, ClientRequestChains, Handler, RequestChain}, quirks::Quirks, + ratelimit::HostPool, retry::RetryExt, types::{redirect_history::RedirectHistory, uri::github::GithubUri}, utils::fragment_checker::{FragmentChecker, FragmentInput}, @@ -10,7 +11,11 @@ use async_trait::async_trait; use http::{Method, StatusCode}; use octocrab::Octocrab; use reqwest::{Request, Response, header::CONTENT_TYPE}; -use std::{collections::HashSet, path::Path, time::Duration}; +use std::{ + collections::{HashMap, HashSet}, + path::Path, + time::Duration, +}; use url::Url; #[derive(Debug, Clone)] @@ -54,9 +59,61 @@ pub(crate) struct WebsiteChecker { /// Keep track of HTTP redirections for reporting redirect_history: RedirectHistory, + + /// Optional host pool for per-host rate limiting. + /// + /// When present, HTTP requests will be routed through this pool for + /// rate limiting. When None, requests go directly through `reqwest_client`. + host_pool: Option, } impl WebsiteChecker { + /// Get per-host statistics from the rate limiting system + /// + /// Returns a map of hostnames to their statistics, or an empty map + /// if host-based rate limiting is not enabled. + #[must_use] + pub(crate) fn host_stats(&self) -> HashMap { + self.host_pool + .as_ref() + .map_or_else(HashMap::default, HostPool::all_host_stats) + } + + /// Get cache statistics for all hosts + /// + /// Returns a map of hostnames to (`cache_size`, `hit_rate`), or an empty map + /// if host-based rate limiting is not enabled. + #[must_use] + pub(crate) fn cache_stats(&self) -> HashMap { + self.host_pool + .as_ref() + .map_or_else(HashMap::default, HostPool::cache_stats) + } + + /// Record a cache hit for the given URI in the host statistics + /// + /// This tracks that a request was served from the persistent cache + /// rather than making a network request. + pub(crate) fn record_cache_hit(&self, uri: &crate::Uri) -> crate::Result<()> { + if let Some(host_pool) = &self.host_pool { + host_pool.record_cache_hit(uri).map_err(Into::into) + } else { + Ok(()) // No host pool, nothing to track + } + } + + /// Record a cache miss for the given URI in the host statistics + /// + /// This tracks that a request could not be served from the persistent cache + /// and will require a network request (which may then use the in-memory cache). + pub(crate) fn record_cache_miss(&self, uri: &crate::Uri) -> crate::Result<()> { + if let Some(host_pool) = &self.host_pool { + host_pool.record_cache_miss(uri).map_err(Into::into) + } else { + Ok(()) // No host pool, nothing to track + } + } + #[allow(clippy::too_many_arguments)] pub(crate) fn new( method: reqwest::Method, @@ -69,6 +126,7 @@ impl WebsiteChecker { require_https: bool, plugin_request_chain: RequestChain, include_fragments: bool, + host_pool: Option, ) -> Self { Self { method, @@ -82,6 +140,7 @@ impl WebsiteChecker { require_https, include_fragments, fragment_checker: FragmentChecker::new(), + host_pool, } } @@ -109,7 +168,24 @@ impl WebsiteChecker { let method = request.method().clone(); let request_url = request.url().clone(); - match self.reqwest_client.execute(request).await { + // Use HostPool for rate limiting - always enabled for HTTP requests + let response_result = if let Some(host_pool) = &self.host_pool { + match host_pool.execute_request(request).await { + Ok(response) => Ok(response), + Err(crate::ratelimit::RateLimitError::NetworkError { source, .. }) => { + // Network errors should be handled the same as direct client errors + Err(source) + } + Err(e) => { + // Rate limiting specific errors + return Status::Error(ErrorKind::RateLimit(e)); + } + } + } else { + self.reqwest_client.execute(request).await + }; + + match response_result { Ok(response) => { let status = Status::new(&response, &self.accepted); // when `accept=200,429`, `status_code=429` will be treated as success diff --git a/lychee-lib/src/client.rs b/lychee-lib/src/client.rs index 63306bb91c..9c93bb49cf 100644 --- a/lychee-lib/src/client.rs +++ b/lychee-lib/src/client.rs @@ -304,6 +304,12 @@ pub struct ClientBuilder { /// early and return a status, so that subsequent chain items are /// skipped and the lychee-internal request chain is not activated. plugin_request_chain: RequestChain, + + /// Optional host pool for per-host rate limiting of HTTP requests. + /// + /// When provided, HTTP/HTTPS requests will be routed through this pool + /// for rate limiting and concurrency control on a per-host basis. + host_pool: Option, } impl Default for ClientBuilder { @@ -412,6 +418,7 @@ impl ClientBuilder { self.require_https, self.plugin_request_chain, self.include_fragments, + self.host_pool, ); Ok(Client { @@ -467,6 +474,48 @@ pub struct Client { } impl Client { + /// Get per-host statistics from the rate limiting system + /// + /// Returns a map of hostnames to their statistics, or an empty map + /// if host-based rate limiting is not enabled. + #[must_use] + pub fn host_stats(&self) -> std::collections::HashMap { + self.website_checker.host_stats() + } + + /// Get cache statistics for all hosts + /// + /// Returns a map of hostnames to (`cache_size`, `hit_rate`), or an empty map + /// if host-based rate limiting is not enabled. + #[must_use] + pub fn cache_stats(&self) -> std::collections::HashMap { + self.website_checker.cache_stats() + } + + /// Record a cache hit for the given URI + /// + /// This tracks that a request was served from cache rather than making + /// a network request. This is used for statistics tracking. + /// + /// # Errors + /// + /// Returns an error if the URI cannot be parsed or if host tracking fails. + pub fn record_cache_hit(&self, uri: &crate::Uri) -> crate::Result<()> { + self.website_checker.record_cache_hit(uri) + } + + /// Record a cache miss for the given URI + /// + /// This tracks that a request could not be served from cache and will + /// require a network request. This is used for statistics tracking. + /// + /// # Errors + /// + /// Returns an error if the URI cannot be parsed or if host tracking fails. + pub fn record_cache_miss(&self, uri: &crate::Uri) -> crate::Result<()> { + self.website_checker.record_cache_miss(uri) + } + /// Check a single request. /// /// `request` can be either a [`Request`] or a type that can be converted diff --git a/lychee-lib/src/extract/html/html5gum.rs b/lychee-lib/src/extract/html/html5gum.rs index ffa03a3a6d..2ae498a4ae 100644 --- a/lychee-lib/src/extract/html/html5gum.rs +++ b/lychee-lib/src/extract/html/html5gum.rs @@ -249,7 +249,6 @@ impl LinkExtractor { if let Some(name) = self.current_attributes.get("name") { self.fragments.insert(name.to_string()); } - self.current_attributes.clear(); } } diff --git a/lychee-lib/src/lib.rs b/lychee-lib/src/lib.rs index 37dfd62ea9..371013c77d 100644 --- a/lychee-lib/src/lib.rs +++ b/lychee-lib/src/lib.rs @@ -68,6 +68,9 @@ pub mod extract; pub mod remap; +/// Per-host rate limiting and concurrency control +pub mod ratelimit; + /// Filters are a way to define behavior when encountering /// URIs that need to be treated differently, such as /// local IPs or e-mail addresses diff --git a/lychee-lib/src/ratelimit/config.rs b/lychee-lib/src/ratelimit/config.rs new file mode 100644 index 0000000000..5993d0197b --- /dev/null +++ b/lychee-lib/src/ratelimit/config.rs @@ -0,0 +1,207 @@ +use http::{HeaderMap, HeaderName, HeaderValue}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; + +/// Global rate limiting configuration that applies as defaults to all hosts +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct RateLimitConfig { + /// Default maximum concurrent requests per host + #[serde(default = "default_host_concurrency")] + pub host_concurrency: usize, + + /// Default minimum interval between requests to the same host + #[serde(default = "default_request_interval")] + #[serde(with = "humantime_serde")] + pub request_interval: Duration, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + host_concurrency: default_host_concurrency(), + request_interval: default_request_interval(), + } + } +} + +impl RateLimitConfig { + /// Create a `RateLimitConfig` from CLI options, using defaults for missing values + #[must_use] + pub fn from_options( + host_concurrency: Option, + request_interval: Option, + ) -> Self { + Self { + host_concurrency: host_concurrency.unwrap_or(DEFAULT_HOST_CONCURRENCY), + request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL), + } + } +} + +/// Configuration for a specific host's rate limiting behavior +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct HostConfig { + /// Maximum concurrent requests allowed to this host + pub max_concurrent: Option, + + /// Minimum interval between requests to this host + #[serde(with = "humantime_serde")] + pub request_interval: Option, + + /// Custom headers to send with requests to this host + #[serde(default)] + #[serde(deserialize_with = "deserialize_headers")] + #[serde(serialize_with = "serialize_headers")] + pub headers: HeaderMap, +} + +impl Default for HostConfig { + fn default() -> Self { + Self { + max_concurrent: None, + request_interval: None, + headers: HeaderMap::new(), + } + } +} + +impl HostConfig { + /// Get the effective max concurrency, falling back to the global default + #[must_use] + pub fn effective_max_concurrent(&self, global_config: &RateLimitConfig) -> usize { + self.max_concurrent + .unwrap_or(global_config.host_concurrency) + } + + /// Get the effective request interval, falling back to the global default + #[must_use] + pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration { + self.request_interval + .unwrap_or(global_config.request_interval) + } +} + +/// Default number of concurrent requests per host +const DEFAULT_HOST_CONCURRENCY: usize = 10; + +/// Default interval between requests to the same host +const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(100); + +/// Default number of concurrent requests per host +const fn default_host_concurrency() -> usize { + DEFAULT_HOST_CONCURRENCY +} + +/// Default interval between requests to the same host +const fn default_request_interval() -> Duration { + DEFAULT_REQUEST_INTERVAL +} + +/// Custom deserializer for headers from TOML config format +fn deserialize_headers<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let map = HashMap::::deserialize(deserializer)?; + let mut header_map = HeaderMap::new(); + + for (name, value) in map { + let header_name = HeaderName::from_bytes(name.as_bytes()) + .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?; + let header_value = HeaderValue::from_str(&value).map_err(|e| { + serde::de::Error::custom(format!("Invalid header value '{value}': {e}")) + })?; + header_map.insert(header_name, header_value); + } + + Ok(header_map) +} + +/// Custom serializer for headers to TOML config format +fn serialize_headers(headers: &HeaderMap, serializer: S) -> Result +where + S: serde::Serializer, +{ + let map: HashMap = headers + .iter() + .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string())) + .collect(); + map.serialize(serializer) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_rate_limit_config() { + let config = RateLimitConfig::default(); + assert_eq!(config.host_concurrency, 10); + assert_eq!(config.request_interval, Duration::from_millis(100)); + } + + #[test] + fn test_host_config_effective_values() { + let global_config = RateLimitConfig::default(); + + // Test with no overrides + let host_config = HostConfig::default(); + assert_eq!(host_config.effective_max_concurrent(&global_config), 10); + assert_eq!( + host_config.effective_request_interval(&global_config), + Duration::from_millis(100) + ); + + // Test with overrides + let host_config = HostConfig { + max_concurrent: Some(5), + request_interval: Some(Duration::from_millis(500)), + headers: HeaderMap::new(), + }; + assert_eq!(host_config.effective_max_concurrent(&global_config), 5); + assert_eq!( + host_config.effective_request_interval(&global_config), + Duration::from_millis(500) + ); + } + + #[test] + fn test_config_serialization() { + let config = RateLimitConfig { + host_concurrency: 15, + request_interval: Duration::from_millis(200), + }; + + let toml = toml::to_string(&config).unwrap(); + let deserialized: RateLimitConfig = toml::from_str(&toml).unwrap(); + + assert_eq!(config.host_concurrency, deserialized.host_concurrency); + assert_eq!(config.request_interval, deserialized.request_interval); + } + + #[test] + fn test_headers_serialization() { + let mut headers = HeaderMap::new(); + headers.insert("Authorization", "Bearer token123".parse().unwrap()); + headers.insert("User-Agent", "test-agent".parse().unwrap()); + + let host_config = HostConfig { + max_concurrent: Some(5), + request_interval: Some(Duration::from_millis(500)), + headers, + }; + + let toml = toml::to_string(&host_config).unwrap(); + let deserialized: HostConfig = toml::from_str(&toml).unwrap(); + + assert_eq!(deserialized.max_concurrent, Some(5)); + assert_eq!( + deserialized.request_interval, + Some(Duration::from_millis(500)) + ); + assert_eq!(deserialized.headers.len(), 2); + assert!(deserialized.headers.contains_key("authorization")); + assert!(deserialized.headers.contains_key("user-agent")); + } +} diff --git a/lychee-lib/src/ratelimit/error.rs b/lychee-lib/src/ratelimit/error.rs new file mode 100644 index 0000000000..c39f463d5b --- /dev/null +++ b/lychee-lib/src/ratelimit/error.rs @@ -0,0 +1,51 @@ +use thiserror::Error; + +/// Errors that can occur during rate limiting operations +#[derive(Error, Debug)] +pub enum RateLimitError { + /// Host exceeded its rate limit + #[error("Host {host} exceeded rate limit: {message}")] + RateLimitExceeded { + /// The host that exceeded the limit + host: String, + /// Additional context message + message: String, + }, + + /// Failed to parse rate limit headers from server response + #[error("Failed to parse rate limit headers from {host}: {reason}")] + HeaderParseError { + /// The host that sent invalid headers + host: String, + /// Reason for parse failure + reason: String, + }, + + /// Error creating or configuring HTTP client for host + #[error("Failed to configure client for host {host}: {source}")] + ClientConfigError { + /// The host that failed configuration + host: String, + /// Underlying error + source: reqwest::Error, + }, + + /// Cookie store operation failed + #[error("Cookie operation failed for host {host}: {reason}")] + CookieError { + /// The host with cookie issues + host: String, + /// Description of cookie error + reason: String, + }, + + /// Network error occurred during request execution + #[error("Network error for host {host}: {source}")] + NetworkError { + /// The host that had the network error + host: String, + /// The underlying network error + #[source] + source: reqwest::Error, + }, +} diff --git a/lychee-lib/src/ratelimit/host/host.rs b/lychee-lib/src/ratelimit/host/host.rs new file mode 100644 index 0000000000..98f4d176d6 --- /dev/null +++ b/lychee-lib/src/ratelimit/host/host.rs @@ -0,0 +1,527 @@ +use dashmap::DashMap; +use governor::{ + Quota, RateLimiter, + clock::DefaultClock, + state::{InMemoryState, NotKeyed}, +}; +use reqwest::{Client as ReqwestClient, Request, Response, redirect}; +use reqwest_cookie_store::CookieStoreMutex; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; +use tokio::sync::Semaphore; + +use super::key::HostKey; +use super::stats::HostStats; +use crate::ratelimit::{HostConfig, RateLimitConfig, RateLimitError}; +use crate::{CacheStatus, Status, Uri}; + +/// Cache value for per-host caching +#[derive(Debug, Clone)] +struct HostCacheValue { + status: CacheStatus, + timestamp: Instant, +} + +impl From<&Status> for HostCacheValue { + fn from(status: &Status) -> Self { + Self { + status: status.into(), + timestamp: Instant::now(), + } + } +} + +/// Per-host cache for storing request results +type HostCache = DashMap; + +/// Represents a single host with its own rate limiting, concurrency control, +/// HTTP client configuration, and request cache. +/// +/// Each host maintains: +/// - A token bucket rate limiter using governor +/// - A semaphore for concurrency control +/// - A dedicated HTTP client with host-specific headers and cookies +/// - Statistics tracking for adaptive behavior +/// - A per-host cache to prevent duplicate requests +#[derive(Debug)] +pub struct Host { + /// The hostname this instance manages + pub key: HostKey, + + /// Rate limiter using token bucket algorithm + rate_limiter: RateLimiter, + + /// Controls maximum concurrent requests to this host + semaphore: Arc, + + /// HTTP client configured for this specific host + client: ReqwestClient, + + /// Cookie jar for maintaining session state (per-host) + #[allow(dead_code)] + cookie_jar: Arc, + + /// Request statistics and adaptive behavior tracking + stats: Arc>, + + /// Current backoff duration for adaptive rate limiting + backoff_duration: Arc>, + + /// Per-host cache to prevent duplicate requests + cache: HostCache, + + /// Maximum age for cached entries (in seconds) + cache_max_age: u64, +} + +impl Host { + /// Create a new Host instance for the given hostname + /// + /// # Arguments + /// + /// * `key` - The hostname this host will manage + /// * `host_config` - Host-specific configuration + /// * `global_config` - Global defaults to fall back to + /// * `cache_max_age` - Maximum age for cached entries in seconds (0 to disable caching) + /// * `shared_cookie_jar` - Optional shared cookie jar to use instead of creating per-host jar + /// * `global_headers` - Global headers to be applied to all requests (User-Agent, custom headers, etc.) + /// * `max_redirects` - Maximum number of redirects to follow + /// * `timeout` - Request timeout + /// * `allow_insecure` - Whether to allow insecure certificates + /// + /// # Errors + /// + /// Returns an error if the HTTP client cannot be configured properly + /// + /// # Panics + /// + /// Panics if the burst size cannot be set to 1 (should never happen) + #[allow(clippy::too_many_arguments)] + pub fn new( + key: HostKey, + host_config: &HostConfig, + global_config: &RateLimitConfig, + cache_max_age: u64, + shared_cookie_jar: Option>, + global_headers: &http::HeaderMap, + max_redirects: usize, + timeout: Option, + allow_insecure: bool, + ) -> Result { + let interval = host_config.effective_request_interval(global_config); + let quota = Quota::with_period(interval) + .ok_or_else(|| RateLimitError::HeaderParseError { + host: key.to_string(), + reason: "Invalid rate limit interval".to_string(), + })? + .allow_burst(std::num::NonZeroU32::new(1).unwrap()); + + let rate_limiter = RateLimiter::direct(quota); + + // Create semaphore for concurrency control + let max_concurrent = host_config.effective_max_concurrent(global_config); + let semaphore = Arc::new(Semaphore::new(max_concurrent)); + + // Use shared cookie jar if provided, otherwise create per-host one + let cookie_jar = shared_cookie_jar.unwrap_or_else(|| Arc::new(CookieStoreMutex::default())); + + // Combine global headers with host-specific headers + let mut combined_headers = global_headers.clone(); + for (name, value) in &host_config.headers { + combined_headers.insert(name, value.clone()); + } + + // Create custom redirect policy matching main client behavior + let redirect_policy = redirect::Policy::custom(move |attempt| { + if attempt.previous().len() > max_redirects { + attempt.error("too many redirects") + } else { + log::debug!("Redirecting to {}", attempt.url()); + attempt.follow() + } + }); + + // Build HTTP client with proper configuration + let mut builder = ReqwestClient::builder() + .cookie_provider(cookie_jar.clone()) + .default_headers(combined_headers) + .gzip(true) + .danger_accept_invalid_certs(allow_insecure) + .connect_timeout(Duration::from_secs(10)) // CONNECT_TIMEOUT constant + .tcp_keepalive(Duration::from_secs(60)) // TCP_KEEPALIVE constant + .redirect(redirect_policy); + + if let Some(timeout) = timeout { + builder = builder.timeout(timeout); + } + + let client = builder + .build() + .map_err(|e| RateLimitError::ClientConfigError { + host: key.to_string(), + source: e, + })?; + + Ok(Host { + key, + rate_limiter, + semaphore, + client, + cookie_jar, + stats: Arc::new(Mutex::new(HostStats::default())), + backoff_duration: Arc::new(Mutex::new(Duration::from_millis(0))), + cache: DashMap::new(), + cache_max_age, + }) + } + + /// Check if a URI is cached and return the cached status if valid + /// + /// # Panics + /// + /// Panics if the statistics mutex is poisoned + pub fn get_cached_status(&self, uri: &Uri) -> Option { + if self.cache_max_age == 0 { + // Track cache miss when caching is disabled + self.stats.lock().unwrap().record_cache_miss(); + return None; // Caching disabled + } + + if let Some(entry) = self.cache.get(uri) { + let age = entry.timestamp.elapsed().as_secs(); + if age <= self.cache_max_age { + // Cache hit + self.stats.lock().unwrap().record_cache_hit(); + return Some(entry.status); + } + // Cache entry expired, remove it + drop(entry); + self.cache.remove(uri); + } + // Cache miss + self.stats.lock().unwrap().record_cache_miss(); + None + } + + /// Cache a request result + pub fn cache_result(&self, uri: &Uri, status: &Status) { + if self.cache_max_age > 0 { + let cache_value = HostCacheValue::from(status); + self.cache.insert(uri.clone(), cache_value); + } + } + + /// Execute a request with rate limiting, concurrency control, and caching + /// + /// This method: + /// 1. Checks the per-host cache for existing results + /// 2. If not cached, acquires a semaphore permit for concurrency control + /// 3. Waits for rate limiter permission + /// 4. Applies adaptive backoff if needed + /// 5. Executes the request + /// 6. Updates statistics based on response + /// 7. Parses rate limit headers to adjust future behavior + /// 8. Caches the result for future use + /// + /// # Arguments + /// + /// * `request` - The HTTP request to execute + /// + /// # Errors + /// + /// Returns an error if the request fails or rate limiting is exceeded + /// + /// # Panics + /// + /// Panics if the statistics mutex is poisoned + pub async fn execute_request(&self, request: Request) -> Result { + let uri = Uri::from(request.url().clone()); + + // Note: Cache checking is handled at the HostPool level + // This method focuses on executing the actual HTTP request + + // Acquire semaphore permit for concurrency control + let _permit = + self.semaphore + .acquire() + .await + .map_err(|_| RateLimitError::RateLimitExceeded { + host: self.key.to_string(), + message: "Semaphore acquisition cancelled".to_string(), + })?; + + // Apply adaptive backoff if needed + let backoff_duration = { + let backoff = self.backoff_duration.lock().unwrap(); + *backoff + }; + if !backoff_duration.is_zero() { + log::debug!( + "Host {} applying backoff delay of {}ms due to previous rate limiting or errors", + self.key, + backoff_duration.as_millis() + ); + tokio::time::sleep(backoff_duration).await; + } + + // Wait for rate limiter permission + self.rate_limiter.until_ready().await; + + // Execute the request and track timing + let start_time = Instant::now(); + let response = match self.client.execute(request).await { + Ok(response) => response, + Err(e) => { + // Wrap network/HTTP errors to preserve the original error + return Err(RateLimitError::NetworkError { + host: self.key.to_string(), + source: e, + }); + } + }; + let request_time = start_time.elapsed(); + + // Update statistics based on response + let status_code = response.status().as_u16(); + self.update_stats_and_backoff(status_code, request_time); + + // Parse rate limit headers to adjust behavior + self.parse_rate_limit_headers(&response); + + // Cache the result + let status = Status::Ok(response.status()); + self.cache_result(&uri, &status); + + Ok(response) + } + + /// Update internal statistics and backoff based on the response + fn update_stats_and_backoff(&self, status_code: u16, request_time: Duration) { + // Update statistics + { + let mut stats = self.stats.lock().unwrap(); + stats.record_response(status_code, request_time); + } + + // Update backoff duration based on response + { + let mut backoff = self.backoff_duration.lock().unwrap(); + match status_code { + 200..=299 => { + // Reset backoff on success + *backoff = Duration::from_millis(0); + } + 429 => { + // Exponential backoff on rate limit, capped at 30 seconds + let new_backoff = std::cmp::min( + if backoff.is_zero() { + Duration::from_millis(500) + } else { + *backoff * 2 + }, + Duration::from_secs(30), + ); + log::debug!( + "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms", + self.key, + backoff.as_millis(), + new_backoff.as_millis() + ); + *backoff = new_backoff; + } + 500..=599 => { + // Moderate backoff increase on server errors, capped at 10 seconds + *backoff = std::cmp::min( + *backoff + Duration::from_millis(200), + Duration::from_secs(10), + ); + } + _ => {} // No backoff change for other status codes + } + } + } + + /// Parse rate limit headers from response and adjust behavior + fn parse_rate_limit_headers(&self, response: &Response) { + // Manual parsing of common rate limit headers + // We implement basic parsing here for the most common headers (X-RateLimit-*, Retry-After) + // rather than using the rate-limits crate to keep dependencies minimal + + let headers = response.headers(); + + // Try common rate limit header patterns + let remaining = Self::parse_header_value( + headers, + &[ + "x-ratelimit-remaining", + "x-rate-limit-remaining", + "ratelimit-remaining", + ], + ); + + let limit = Self::parse_header_value( + headers, + &["x-ratelimit-limit", "x-rate-limit-limit", "ratelimit-limit"], + ); + + if let (Some(remaining), Some(limit)) = (remaining, limit) { + if limit > 0 { + #[allow(clippy::cast_precision_loss)] + let usage_ratio = (limit - remaining) as f64 / limit as f64; + + // If we've used more than 80% of our quota, apply preventive backoff + if usage_ratio > 0.8 { + let mut backoff = self.backoff_duration.lock().unwrap(); + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let preventive_backoff = + Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64); + *backoff = std::cmp::max(*backoff, preventive_backoff); + } + } + } + + // Check for Retry-After header (in seconds) + if let Some(retry_after_value) = headers.get("retry-after") { + if let Ok(retry_after_str) = retry_after_value.to_str() { + if let Ok(retry_seconds) = retry_after_str.parse::() { + let mut backoff = self.backoff_duration.lock().unwrap(); + let retry_duration = Duration::from_secs(retry_seconds); + // Cap retry-after to reasonable limits + if retry_duration <= Duration::from_secs(3600) { + *backoff = std::cmp::max(*backoff, retry_duration); + } + } + } + } + } + + /// Helper method to parse numeric header values from common rate limit headers + fn parse_header_value(headers: &http::HeaderMap, header_names: &[&str]) -> Option { + for header_name in header_names { + if let Some(value) = headers.get(*header_name) { + if let Ok(value_str) = value.to_str() { + if let Ok(number) = value_str.parse::() { + return Some(number); + } + } + } + } + None + } + + /// Get host statistics + /// + /// # Panics + /// + /// Panics if the statistics mutex is poisoned + pub fn stats(&self) -> HostStats { + self.stats.lock().unwrap().clone() + } + + /// Record a cache hit from the persistent disk cache + /// + /// # Panics + /// + /// Panics if the statistics mutex is poisoned + pub fn record_persistent_cache_hit(&self) { + self.stats.lock().unwrap().record_cache_hit(); + } + + /// Record a cache miss from the persistent disk cache + /// + /// # Panics + /// + /// Panics if the statistics mutex is poisoned + pub fn record_persistent_cache_miss(&self) { + self.stats.lock().unwrap().record_cache_miss(); + } + + /// Get the current number of available permits (concurrent request slots) + pub fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } + + /// Get the current cache size (number of cached entries) + pub fn cache_size(&self) -> usize { + self.cache.len() + } + + /// Clear expired entries from the cache + pub fn cleanup_cache(&self) { + if self.cache_max_age == 0 { + return; + } + + self.cache + .retain(|_, value| value.timestamp.elapsed().as_secs() <= self.cache_max_age); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ratelimit::{HostConfig, RateLimitConfig}; + use std::time::Duration; + + #[tokio::test] + async fn test_host_creation() { + let key = HostKey::from("example.com"); + let host_config = HostConfig::default(); + let global_config = RateLimitConfig::default(); + + let host = Host::new( + key.clone(), + &host_config, + &global_config, + 3600, + None, + &http::HeaderMap::new(), + 5, + Some(std::time::Duration::from_secs(20)), + false, + ) + .unwrap(); + + assert_eq!(host.key, key); + assert_eq!(host.available_permits(), 10); // Default concurrency + assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON); + assert_eq!(host.cache_size(), 0); + } + + #[test] + fn test_cache_expiration() { + let key = HostKey::from("example.com"); + let host_config = HostConfig::default(); + let global_config = RateLimitConfig::default(); + + let host = Host::new( + key, + &host_config, + &global_config, + 1, + None, + &http::HeaderMap::new(), + 5, + Some(std::time::Duration::from_secs(20)), + false, + ) + .unwrap(); // 1 second cache + + let uri = Uri::from("https://example.com/test".parse::().unwrap()); + let status = Status::Ok(http::StatusCode::OK); + + // Cache the result + host.cache_result(&uri, &status); + assert_eq!(host.cache_size(), 1); + + // Should be in cache immediately + assert!(host.get_cached_status(&uri).is_some()); + + // Wait for expiration and cleanup + std::thread::sleep(Duration::from_secs(2)); + host.cleanup_cache(); + + // Should be expired now + assert!(host.get_cached_status(&uri).is_none()); + } +} diff --git a/lychee-lib/src/ratelimit/host/key.rs b/lychee-lib/src/ratelimit/host/key.rs new file mode 100644 index 0000000000..ffc6f538d0 --- /dev/null +++ b/lychee-lib/src/ratelimit/host/key.rs @@ -0,0 +1,153 @@ +use std::fmt; +use url::Url; + +/// A type-safe representation of a hostname for rate limiting purposes. +/// +/// This extracts and normalizes hostnames from URLs to ensure consistent +/// rate limiting across requests to the same host. Subdomains are treated +/// as separate hosts to allow for traffic sharding. +/// +/// # Examples +/// +/// ``` +/// use lychee_lib::ratelimit::HostKey; +/// use url::Url; +/// +/// let url = Url::parse("https://api.github.com/repos/user/repo").unwrap(); +/// let host_key = HostKey::try_from(&url).unwrap(); +/// assert_eq!(host_key.as_str(), "api.github.com"); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HostKey(String); + +impl HostKey { + /// Get the hostname as a string slice + #[must_use] + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Get the hostname as an owned String + #[must_use] + pub fn into_string(self) -> String { + self.0 + } +} + +impl TryFrom<&Url> for HostKey { + type Error = crate::ratelimit::RateLimitError; + + fn try_from(url: &Url) -> Result { + let host = + url.host_str() + .ok_or_else(|| crate::ratelimit::RateLimitError::HeaderParseError { + host: url.to_string(), + reason: "URL contains no host component".to_string(), + })?; + + // Normalize to lowercase for consistent lookup + Ok(HostKey(host.to_lowercase())) + } +} + +impl TryFrom<&crate::Uri> for HostKey { + type Error = crate::ratelimit::RateLimitError; + + fn try_from(uri: &crate::Uri) -> Result { + Self::try_from(&uri.url) + } +} + +impl TryFrom for HostKey { + type Error = crate::ratelimit::RateLimitError; + + fn try_from(url: Url) -> Result { + HostKey::try_from(&url) + } +} + +impl fmt::Display for HostKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for HostKey { + fn from(host: String) -> Self { + HostKey(host.to_lowercase()) + } +} + +impl From<&str> for HostKey { + fn from(host: &str) -> Self { + HostKey(host.to_lowercase()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_host_key_from_url() { + let url = Url::parse("https://api.github.com/repos/user/repo").unwrap(); + let host_key = HostKey::try_from(&url).unwrap(); + assert_eq!(host_key.as_str(), "api.github.com"); + } + + #[test] + fn test_host_key_normalization() { + let url = Url::parse("https://API.GITHUB.COM/repos/user/repo").unwrap(); + let host_key = HostKey::try_from(&url).unwrap(); + assert_eq!(host_key.as_str(), "api.github.com"); + } + + #[test] + fn test_host_key_subdomain_separation() { + let api_url = Url::parse("https://api.github.com/").unwrap(); + let www_url = Url::parse("https://www.github.com/").unwrap(); + + let api_key = HostKey::try_from(&api_url).unwrap(); + let www_key = HostKey::try_from(&www_url).unwrap(); + + assert_ne!(api_key, www_key); + assert_eq!(api_key.as_str(), "api.github.com"); + assert_eq!(www_key.as_str(), "www.github.com"); + } + + #[test] + fn test_host_key_from_string() { + let host_key = HostKey::from("example.com"); + assert_eq!(host_key.as_str(), "example.com"); + + let host_key = HostKey::from("EXAMPLE.COM"); + assert_eq!(host_key.as_str(), "example.com"); + } + + #[test] + fn test_host_key_no_host() { + let url = Url::parse("file:///path/to/file").unwrap(); + let result = HostKey::try_from(&url); + assert!(result.is_err()); + } + + #[test] + fn test_host_key_display() { + let host_key = HostKey::from("example.com"); + assert_eq!(format!("{host_key}"), "example.com"); + } + + #[test] + fn test_host_key_hash_equality() { + use std::collections::HashMap; + + let key1 = HostKey::from("example.com"); + let key2 = HostKey::from("EXAMPLE.COM"); + + let mut map = HashMap::new(); + map.insert(key1, "value"); + + // Should find the value with normalized key + assert_eq!(map.get(&key2), Some(&"value")); + } +} diff --git a/lychee-lib/src/ratelimit/host/mod.rs b/lychee-lib/src/ratelimit/host/mod.rs new file mode 100644 index 0000000000..50b8b1ad3e --- /dev/null +++ b/lychee-lib/src/ratelimit/host/mod.rs @@ -0,0 +1,9 @@ +#![allow(clippy::module_inception)] + +mod host; +mod key; +mod stats; + +pub use host::Host; +pub use key::HostKey; +pub use stats::HostStats; diff --git a/lychee-lib/src/ratelimit/host/stats.rs b/lychee-lib/src/ratelimit/host/stats.rs new file mode 100644 index 0000000000..4a11575c79 --- /dev/null +++ b/lychee-lib/src/ratelimit/host/stats.rs @@ -0,0 +1,283 @@ +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +use crate::ratelimit::window::Window; + +/// Statistics tracking for a host's request patterns +#[derive(Debug, Clone, Default)] +pub struct HostStats { + /// Total number of requests made to this host + pub total_requests: u64, + /// Number of successful requests (2xx status) + pub successful_requests: u64, + /// Number of requests that received rate limit responses (429) + pub rate_limited: u64, + /// Number of server error responses (5xx) + pub server_errors: u64, + /// Number of client error responses (4xx, excluding 429) + pub client_errors: u64, + /// Timestamp of the last successful request + pub last_success: Option, + /// Timestamp of the last rate limit response + pub last_rate_limit: Option, + /// Request times for median calculation (kept in rolling window) + pub request_times: Window, + /// Status code counts + pub status_codes: HashMap, + /// Number of cache hits + pub cache_hits: u64, + /// Number of cache misses + pub cache_misses: u64, +} + +impl HostStats { + /// Create new host statistics with custom window size for request times + #[must_use] + pub fn with_window_size(window_size: usize) -> Self { + Self { + request_times: Window::new(window_size), + ..Default::default() + } + } + + /// Record a response with status code and request duration + pub fn record_response(&mut self, status_code: u16, request_time: Duration) { + self.total_requests += 1; + + // Track status code + *self.status_codes.entry(status_code).or_insert(0) += 1; + + // Categorize response + match status_code { + 200..=299 => { + self.successful_requests += 1; + self.last_success = Some(Instant::now()); + } + 429 => { + self.rate_limited += 1; + self.last_rate_limit = Some(Instant::now()); + } + 400..=499 => { + self.client_errors += 1; + } + 500..=599 => { + self.server_errors += 1; + } + _ => {} // Other status codes + } + + // Track request time in rolling window + self.request_times.push(request_time); + } + + /// Get median request time + #[must_use] + pub fn median_request_time(&self) -> Option { + if self.request_times.is_empty() { + return None; + } + + let mut times = self.request_times.to_vec(); + times.sort(); + let mid = times.len() / 2; + + if times.len() % 2 == 0 { + // Average of two middle values + Some((times[mid - 1] + times[mid]) / 2) + } else { + Some(times[mid]) + } + } + + /// Get error rate (percentage) + #[must_use] + pub fn error_rate(&self) -> f64 { + if self.total_requests == 0 { + return 0.0; + } + let errors = self.rate_limited + self.client_errors + self.server_errors; + #[allow(clippy::cast_precision_loss)] + let error_rate = errors as f64 / self.total_requests as f64; + error_rate * 100.0 + } + + /// Get the current success rate (0.0 to 1.0) + #[must_use] + pub fn success_rate(&self) -> f64 { + if self.total_requests == 0 { + 1.0 // Assume success until proven otherwise + } else { + #[allow(clippy::cast_precision_loss)] + let success_rate = self.successful_requests as f64 / self.total_requests as f64; + success_rate + } + } + + /// Get average request time + #[must_use] + pub fn average_request_time(&self) -> Option { + if self.request_times.is_empty() { + return None; + } + + let total: Duration = self.request_times.iter().sum(); + #[allow(clippy::cast_possible_truncation)] + Some(total / (self.request_times.len() as u32)) + } + + /// Get the most recent request time + #[must_use] + pub fn latest_request_time(&self) -> Option { + self.request_times.iter().last().copied() + } + + /// Check if this host has been experiencing rate limiting recently + #[must_use] + pub fn is_currently_rate_limited(&self) -> bool { + if let Some(last_rate_limit) = self.last_rate_limit { + // Consider rate limited if we got a 429 in the last 60 seconds + last_rate_limit.elapsed() < Duration::from_secs(60) + } else { + false + } + } + + /// Record a cache hit + pub const fn record_cache_hit(&mut self) { + self.cache_hits += 1; + // Cache hits should also count as total requests from user perspective + self.total_requests += 1; + // Cache hits are typically for successful previous requests, so count as successful + self.successful_requests += 1; + } + + /// Record a cache miss + pub const fn record_cache_miss(&mut self) { + self.cache_misses += 1; + // Cache misses will be followed by actual requests that increment total_requests + // so we don't increment here to avoid double-counting + } + + /// Get cache hit rate (0.0 to 1.0) + #[must_use] + pub fn cache_hit_rate(&self) -> f64 { + let total_cache_requests = self.cache_hits + self.cache_misses; + if total_cache_requests == 0 { + 0.0 + } else { + #[allow(clippy::cast_precision_loss)] + let hit_rate = self.cache_hits as f64 / total_cache_requests as f64; + hit_rate + } + } + + /// Get human-readable summary of the stats + #[must_use] + pub fn summary(&self) -> String { + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let success_pct = (self.success_rate() * 100.0) as u64; + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let error_pct = self.error_rate() as u64; + + let avg_time = self + .average_request_time() + .map_or_else(|| "N/A".to_string(), |d| format!("{:.0}ms", d.as_millis())); + + format!( + "{} requests ({}% success, {}% errors), avg: {}", + self.total_requests, success_pct, error_pct, avg_time + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_host_stats_success_rate() { + let mut stats = HostStats::default(); + + // No requests yet - should assume success + assert!((stats.success_rate() - 1.0).abs() < f64::EPSILON); + + // Record some successful requests + stats.record_response(200, Duration::from_millis(100)); + stats.record_response(200, Duration::from_millis(120)); + assert!((stats.success_rate() - 1.0).abs() < f64::EPSILON); + + // Record a rate limited request + stats.record_response(429, Duration::from_millis(150)); + assert!((stats.success_rate() - (2.0 / 3.0)).abs() < 0.001); + + // Record a server error + stats.record_response(500, Duration::from_millis(200)); + assert!((stats.success_rate() - 0.5).abs() < f64::EPSILON); + } + + #[test] + fn test_host_stats_tracking() { + let mut stats = HostStats::default(); + + // Initially empty + assert_eq!(stats.total_requests, 0); + assert_eq!(stats.successful_requests, 0); + assert!(stats.error_rate().abs() < f64::EPSILON); + + // Record a successful response + stats.record_response(200, Duration::from_millis(100)); + assert_eq!(stats.total_requests, 1); + assert_eq!(stats.successful_requests, 1); + assert!(stats.error_rate().abs() < f64::EPSILON); + assert_eq!(stats.status_codes.get(&200), Some(&1)); + + // Record rate limited response + stats.record_response(429, Duration::from_millis(200)); + assert_eq!(stats.total_requests, 2); + assert_eq!(stats.rate_limited, 1); + assert!((stats.error_rate() - 50.0).abs() < f64::EPSILON); + + // Record server error + stats.record_response(500, Duration::from_millis(150)); + assert_eq!(stats.total_requests, 3); + assert_eq!(stats.server_errors, 1); + + // Check median request time + assert_eq!( + stats.median_request_time(), + Some(Duration::from_millis(150)) + ); + } + + #[test] + fn test_window_integration() { + let mut stats = HostStats::with_window_size(2); + + stats.record_response(200, Duration::from_millis(100)); + stats.record_response(200, Duration::from_millis(200)); + stats.record_response(200, Duration::from_millis(300)); + + // Window should only keep last 2 times + assert_eq!(stats.request_times.len(), 2); + + let times: Vec<_> = stats.request_times.iter().copied().collect(); + assert_eq!( + times, + vec![Duration::from_millis(200), Duration::from_millis(300)] + ); + } + + #[test] + fn test_summary_formatting() { + let mut stats = HostStats::default(); + stats.record_response(200, Duration::from_millis(150)); + stats.record_response(500, Duration::from_millis(200)); + + let summary = stats.summary(); + assert!(summary.contains("2 requests")); + assert!(summary.contains("50% success")); + assert!(summary.contains("50% errors")); + assert!(summary.contains("175ms")); // average of 150 and 200 + } +} diff --git a/lychee-lib/src/ratelimit/mod.rs b/lychee-lib/src/ratelimit/mod.rs new file mode 100644 index 0000000000..eeccd90ba2 --- /dev/null +++ b/lychee-lib/src/ratelimit/mod.rs @@ -0,0 +1,26 @@ +//! Per-host rate limiting and concurrency control. +//! +//! This module provides adaptive rate limiting for HTTP requests on a per-host basis. +//! It prevents overwhelming servers with too many concurrent requests and respects +//! server-provided rate limit headers. +//! +//! # Architecture +//! +//! - [`HostKey`]: Represents a hostname/domain for rate limiting +//! - [`Host`]: Manages rate limiting, concurrency, caching, and cookies for a specific host +//! - [`HostPool`]: Coordinates multiple hosts and routes requests appropriately +//! - [`HostConfig`]: Configuration for per-host behavior +//! - [`HostStats`]: Statistics tracking for each host +//! - [`Window`]: Rolling window data structure for request timing + +mod config; +mod error; +mod host; +mod pool; +mod window; + +pub use config::{HostConfig, RateLimitConfig}; +pub use error::RateLimitError; +pub use host::{Host, HostKey, HostStats}; +pub use pool::HostPool; +pub use window::Window; diff --git a/lychee-lib/src/ratelimit/pool.rs b/lychee-lib/src/ratelimit/pool.rs new file mode 100644 index 0000000000..ac2f471a29 --- /dev/null +++ b/lychee-lib/src/ratelimit/pool.rs @@ -0,0 +1,520 @@ +use dashmap::DashMap; +use http::HeaderMap; +use reqwest::{Request, Response}; +use reqwest_cookie_store::CookieStoreMutex; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Semaphore; + +use crate::ratelimit::{Host, HostConfig, HostKey, HostStats, RateLimitConfig, RateLimitError}; +use crate::{CacheStatus, Status, Uri}; + +/// Manages a pool of Host instances and routes requests to appropriate hosts. +/// +/// The `HostPool` serves as the central coordinator for per-host rate limiting. +/// It creates Host instances on-demand, manages global concurrency limits, +/// and provides a unified interface for executing HTTP requests with +/// appropriate rate limiting applied. +/// +/// # Architecture +/// +/// - Each unique hostname gets its own Host instance with dedicated rate limiting +/// - Global semaphore enforces overall concurrency limits across all hosts +/// - Hosts are created lazily when first requested +/// - Thread-safe using `DashMap` for concurrent access to host instances +#[derive(Debug, Clone)] +pub struct HostPool { + /// Map of hostname to Host instances, created on-demand + hosts: Arc>>, + + /// Global configuration for rate limiting defaults + global_config: Arc, + + /// Per-host configuration overrides + host_configs: Arc>, + + /// Global semaphore to enforce overall concurrency limit + global_semaphore: Arc, + + /// Maximum age for cached entries in seconds (0 to disable caching) + cache_max_age: u64, + + /// Shared cookie jar used across all hosts + cookie_jar: Option>, + + /// Global headers to be applied to all requests (includes User-Agent, etc.) + global_headers: HeaderMap, + + /// Maximum number of redirects to follow + max_redirects: usize, + + /// Request timeout + timeout: Option, + + /// Whether to allow insecure certificates + allow_insecure: bool, +} + +impl HostPool { + /// Create a new `HostPool` with the given configuration + /// + /// # Arguments + /// + /// * `global_config` - Default rate limiting configuration + /// * `host_configs` - Host-specific configuration overrides + /// * `max_total_concurrency` - Global limit on concurrent requests across all hosts + /// * `cache_max_age` - Maximum age for cached entries in seconds (0 to disable caching) + /// * `global_headers` - Headers to be applied to all requests (User-Agent, custom headers, etc.) + /// * `max_redirects` - Maximum number of redirects to follow + /// * `timeout` - Request timeout + /// * `allow_insecure` - Whether to allow insecure certificates + /// + /// # Examples + /// + /// ``` + /// use lychee_lib::ratelimit::{HostPool, RateLimitConfig}; + /// use std::collections::HashMap; + /// use http::HeaderMap; + /// use std::time::Duration; + /// + /// let global_config = RateLimitConfig::default(); + /// let host_configs = HashMap::new(); + /// let global_headers = HeaderMap::new(); + /// let pool = HostPool::new(global_config, host_configs, 128, 3600, global_headers, 5, Some(Duration::from_secs(20)), false); + /// ``` + #[must_use] + #[allow(clippy::too_many_arguments)] + pub fn new( + global_config: RateLimitConfig, + host_configs: HashMap, + max_total_concurrency: usize, + cache_max_age: u64, + global_headers: HeaderMap, + max_redirects: usize, + timeout: Option, + allow_insecure: bool, + ) -> Self { + Self { + hosts: Arc::new(DashMap::new()), + global_config: Arc::new(global_config), + host_configs: Arc::new(host_configs), + global_semaphore: Arc::new(Semaphore::new(max_total_concurrency)), + cache_max_age, + cookie_jar: None, + global_headers, + max_redirects, + timeout, + allow_insecure, + } + } + + /// Add a shared cookie jar to the `HostPool` + #[must_use] + pub fn with_cookie_jar(mut self, cookie_jar: Arc) -> Self { + self.cookie_jar = Some(cookie_jar); + self + } + + /// Execute an HTTP request with appropriate per-host rate limiting + /// + /// This method: + /// 1. Extracts the hostname from the request URL + /// 2. Gets or creates the appropriate Host instance + /// 3. Acquires a global semaphore permit + /// 4. Delegates to the host for execution with host-specific rate limiting + /// + /// # Arguments + /// + /// * `request` - The HTTP request to execute + /// + /// # Errors + /// + /// Returns a `RateLimitError` if: + /// - The request URL has no valid hostname + /// - Global or host-specific rate limits are exceeded + /// - The underlying HTTP request fails + /// + /// # Examples + /// + /// ```no_run + /// # use lychee_lib::ratelimit::{HostPool, RateLimitConfig}; + /// # use std::collections::HashMap; + /// # use reqwest::{Request, header::HeaderMap}; + /// # use std::time::Duration; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// let pool = HostPool::new( + /// RateLimitConfig::default(), + /// HashMap::new(), + /// 128, + /// 3600, + /// HeaderMap::new(), + /// 5, + /// Some(Duration::from_secs(20)), + /// false + /// ); + /// let request = reqwest::Request::new(reqwest::Method::GET, "https://example.com".parse()?); + /// let response = pool.execute_request(request).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn execute_request(&self, request: Request) -> Result { + // Extract hostname from request URL + let url = request.url(); + let host_key = HostKey::try_from(url)?; + + // Get or create host instance + let host = self.get_or_create_host(host_key)?; + + // Acquire global semaphore permit first + let _global_permit = self.global_semaphore.acquire().await.map_err(|_| { + RateLimitError::RateLimitExceeded { + host: host.key.to_string(), + message: "Global concurrency limit reached".to_string(), + } + })?; + + // Execute request through host-specific rate limiting + host.execute_request(request).await + } + + /// Get an existing host or create a new one for the given hostname + fn get_or_create_host(&self, host_key: HostKey) -> Result, RateLimitError> { + // Check if host already exists + if let Some(host) = self.hosts.get(&host_key) { + return Ok(host.clone()); + } + + // Create new host instance + let host_config = self + .host_configs + .get(host_key.as_str()) + .cloned() + .unwrap_or_default(); + + let host = Arc::new(Host::new( + host_key.clone(), + &host_config, + &self.global_config, + self.cache_max_age, + self.cookie_jar.clone(), + &self.global_headers, + self.max_redirects, + self.timeout, + self.allow_insecure, + )?); + + // Store in map (handle race condition where another thread created it) + match self.hosts.entry(host_key) { + dashmap::mapref::entry::Entry::Occupied(entry) => { + // Another thread created it, use theirs + Ok(entry.get().clone()) + } + dashmap::mapref::entry::Entry::Vacant(entry) => { + // We're first, insert ours + Ok(entry.insert(host).clone()) + } + } + } + + /// Get statistics for a specific host + /// + /// Returns statistics for the host if it exists, otherwise returns empty stats. + /// This provides consistent behavior whether or not requests have been made to that host yet. + /// + /// # Arguments + /// + /// * `hostname` - The hostname to get statistics for + #[must_use] + pub fn host_stats(&self, hostname: &str) -> HostStats { + let host_key = HostKey::from(hostname); + self.hosts + .get(&host_key) + .map(|host| host.stats()) + .unwrap_or_default() + } + + /// Get statistics for all hosts that have been created + /// + /// Returns a `HashMap` mapping hostnames to their statistics. + /// Only hosts that have had requests will be included. + #[must_use] + pub fn all_host_stats(&self) -> HashMap { + self.hosts + .iter() + .map(|entry| { + let hostname = entry.key().to_string(); + let stats = entry.value().stats(); + (hostname, stats) + }) + .collect() + } + + /// Get the number of currently active hosts + /// + /// This returns the number of Host instances that have been created, + /// which corresponds to the number of unique hostnames that have + /// been accessed. + #[must_use] + pub fn active_host_count(&self) -> usize { + self.hosts.len() + } + + /// Get the number of available global permits + /// + /// This shows how many more concurrent requests can be started + /// across all hosts before hitting the global concurrency limit. + #[must_use] + pub fn available_global_permits(&self) -> usize { + self.global_semaphore.available_permits() + } + + /// Get host configuration for debugging/monitoring + /// + /// Returns a copy of the current host-specific configurations. + /// This is useful for debugging or runtime monitoring of configuration. + #[must_use] + pub fn host_configurations(&self) -> HashMap { + (*self.host_configs).clone() + } + + /// Remove a host from the pool + /// + /// This forces the host to be recreated with updated configuration + /// the next time a request is made to it. Any ongoing requests to + /// that host will continue with the old instance. + /// + /// # Arguments + /// + /// * `hostname` - The hostname to remove from the pool + /// + /// # Returns + /// + /// Returns true if a host was removed, false if no host existed for that hostname. + #[must_use] + pub fn remove_host(&self, hostname: &str) -> bool { + let host_key = HostKey::from(hostname); + self.hosts.remove(&host_key).is_some() + } + + /// Check if a URI is cached in the appropriate host's cache + /// + /// # Arguments + /// + /// * `uri` - The URI to check for in the cache + /// + /// # Returns + /// + /// Returns the cached status if found and valid, None otherwise + #[must_use] + pub fn get_cached_status(&self, uri: &Uri) -> Option { + let host_key = HostKey::try_from(uri).ok()?; + + if let Some(host) = self.hosts.get(&host_key) { + host.get_cached_status(uri) + } else { + None + } + } + + /// Cache a result for a URI in the appropriate host's cache + /// + /// # Arguments + /// + /// * `uri` - The URI to cache + /// * `status` - The status to cache + pub fn cache_result(&self, uri: &Uri, status: &Status) { + if let Ok(host_key) = HostKey::try_from(uri) { + if let Some(host) = self.hosts.get(&host_key) { + host.cache_result(uri, status); + } + // If host doesn't exist yet, we don't cache + // The result will be cached when the host is created and the request is made + } + } + + /// Get cache statistics across all hosts + #[must_use] + pub fn cache_stats(&self) -> HashMap { + self.hosts + .iter() + .map(|entry| { + let hostname = entry.key().to_string(); + let cache_size = entry.value().cache_size(); + let hit_rate = entry.value().stats().cache_hit_rate(); + (hostname, (cache_size, hit_rate)) + }) + .collect() + } + + /// Cleanup expired cache entries across all hosts + pub fn cleanup_caches(&self) { + for host in self.hosts.iter() { + host.cleanup_cache(); + } + } + + /// Record a cache hit for the given URI in host statistics + /// + /// This tracks that a request was served from the persistent disk cache + /// rather than going through the rate-limited HTTP request flow. + /// This method will create a [Host] instance if one doesn't exist yet. + /// + /// # Errors + /// + /// Returns an error if the host key cannot be parsed from the URI or if the host cannot be created. + pub fn record_cache_hit( + &self, + uri: &crate::Uri, + ) -> Result<(), crate::ratelimit::RateLimitError> { + let host_key = crate::ratelimit::HostKey::try_from(uri)?; + + // Get or create the host (this ensures statistics tracking even for cache-only requests) + let host = self.get_or_create_host(host_key)?; + host.record_persistent_cache_hit(); + Ok(()) + } + + /// Record a cache miss for the given URI in host statistics + /// + /// This tracks that a request could not be served from the persistent disk cache + /// and will need to go through the rate-limited HTTP request flow. + /// This method will create a Host instance if one doesn't exist yet. + /// + /// # Errors + /// + /// Returns an error if the host key cannot be parsed from the URI or if the host cannot be created. + pub fn record_cache_miss( + &self, + uri: &crate::Uri, + ) -> Result<(), crate::ratelimit::RateLimitError> { + let host_key = crate::ratelimit::HostKey::try_from(uri)?; + + // Get or create the host (this ensures statistics tracking even for cache-only requests) + let host = self.get_or_create_host(host_key)?; + host.record_persistent_cache_miss(); + Ok(()) + } +} + +impl Default for HostPool { + fn default() -> Self { + Self::new( + RateLimitConfig::default(), + HashMap::new(), + 128, // Default global concurrency limit + 3600, // Default cache age of 1 hour + HeaderMap::new(), // Default empty headers + 5, // Default max redirects + Some(Duration::from_secs(20)), // Default timeout + false, // Default secure certificates + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ratelimit::RateLimitConfig; + + use url::Url; + + #[test] + fn test_host_pool_creation() { + let global_config = RateLimitConfig::default(); + let host_configs = HashMap::new(); + let pool = HostPool::new( + global_config, + host_configs, + 100, + 3600, + HeaderMap::new(), + 5, + Some(Duration::from_secs(20)), + false, + ); + + assert_eq!(pool.active_host_count(), 0); + assert_eq!(pool.available_global_permits(), 100); + } + + #[test] + fn test_host_pool_default() { + let pool = HostPool::default(); + + assert_eq!(pool.active_host_count(), 0); + assert_eq!(pool.available_global_permits(), 128); + } + + #[tokio::test] + async fn test_host_creation_on_demand() { + let pool = HostPool::default(); + let url: Url = "https://example.com/path".parse().unwrap(); + let host_key = HostKey::try_from(&url).unwrap(); + + // No hosts initially + assert_eq!(pool.active_host_count(), 0); + assert_eq!(pool.host_stats("example.com").total_requests, 0); + + // Create host on demand + let host = pool.get_or_create_host(host_key).unwrap(); + + // Now we have one host + assert_eq!(pool.active_host_count(), 1); + assert_eq!(pool.host_stats("example.com").total_requests, 0); + assert_eq!(host.key.as_str(), "example.com"); + } + + #[tokio::test] + async fn test_host_reuse() { + let pool = HostPool::default(); + let url: Url = "https://example.com/path1".parse().unwrap(); + let host_key1 = HostKey::try_from(&url).unwrap(); + + let url: Url = "https://example.com/path2".parse().unwrap(); + let host_key2 = HostKey::try_from(&url).unwrap(); + + // Create host for first request + let host1 = pool.get_or_create_host(host_key1).unwrap(); + assert_eq!(pool.active_host_count(), 1); + + // Second request to same host should reuse + let host2 = pool.get_or_create_host(host_key2).unwrap(); + assert_eq!(pool.active_host_count(), 1); + + // Should be the same instance + assert!(Arc::ptr_eq(&host1, &host2)); + } + + #[test] + fn test_host_config_management() { + let pool = HostPool::default(); + + // Initially no host configurations + let configs = pool.host_configurations(); + assert_eq!(configs.len(), 0); + } + + #[test] + fn test_host_removal() { + let pool = HostPool::default(); + + // Remove non-existent host + assert!(!pool.remove_host("nonexistent.com")); + + // We can't easily test removal of existing hosts without making actual requests + // due to the async nature of host creation, but the basic functionality works + } + + #[test] + fn test_all_host_stats() { + let pool = HostPool::default(); + + // No hosts initially + let stats = pool.all_host_stats(); + assert!(stats.is_empty()); + + // Stats would be populated after actual requests are made to create hosts + } +} diff --git a/lychee-lib/src/ratelimit/window.rs b/lychee-lib/src/ratelimit/window.rs new file mode 100644 index 0000000000..058641e0a0 --- /dev/null +++ b/lychee-lib/src/ratelimit/window.rs @@ -0,0 +1,100 @@ +use std::collections::VecDeque; + +/// A rolling window data structure that automatically maintains a maximum size +/// by removing oldest elements when the capacity is exceeded. +#[derive(Debug, Clone)] +pub struct Window { + data: VecDeque, + capacity: usize, +} + +impl Window { + /// Create a new window with the given capacity + #[must_use] + pub fn new(capacity: usize) -> Self { + Self { + data: VecDeque::with_capacity(capacity), + capacity, + } + } + + /// Push an element to the window, removing the oldest if at capacity + pub fn push(&mut self, item: T) { + if self.data.len() >= self.capacity { + self.data.pop_front(); + } + self.data.push_back(item); + } + + /// Get the number of elements currently in the window + #[must_use] + pub fn len(&self) -> usize { + self.data.len() + } + + /// Check if the window is empty + #[must_use] + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Get an iterator over the elements in the window + pub fn iter(&self) -> impl Iterator { + self.data.iter() + } + + /// Convert to a vector (for compatibility with existing code) + #[must_use] + pub fn to_vec(&self) -> Vec + where + T: Clone, + { + self.data.iter().cloned().collect() + } +} + +impl Default for Window { + fn default() -> Self { + Self::new(100) // Default capacity of 100 items + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_window_capacity() { + let mut window = Window::new(3); + + // Fill up the window + window.push(1); + window.push(2); + window.push(3); + assert_eq!(window.len(), 3); + + // Add one more, should remove the oldest + window.push(4); + assert_eq!(window.len(), 3); + + let values: Vec<_> = window.iter().copied().collect(); + assert_eq!(values, vec![2, 3, 4]); + } + + #[test] + fn test_window_empty() { + let window: Window = Window::new(5); + assert!(window.is_empty()); + assert_eq!(window.len(), 0); + } + + #[test] + fn test_window_to_vec() { + let mut window = Window::new(3); + window.push(1); + window.push(2); + + let vec = window.to_vec(); + assert_eq!(vec, vec![1, 2]); + } +} diff --git a/lychee-lib/src/types/error.rs b/lychee-lib/src/types/error.rs index 30019b9bc2..5b5eebeeee 100644 --- a/lychee-lib/src/types/error.rs +++ b/lychee-lib/src/types/error.rs @@ -170,6 +170,10 @@ pub enum ErrorKind { #[error("Status code range error")] StatusCodeSelectorError(#[from] StatusCodeSelectorError), + /// Rate limiting error + #[error("Rate limiting error: {0}")] + RateLimit(#[from] crate::ratelimit::RateLimitError), + /// Test-only error variant for formatter tests /// Available in both test and debug builds to support cross-crate testing #[cfg(any(test, debug_assertions))] @@ -333,6 +337,9 @@ impl ErrorKind { ErrorKind::InvalidIndexFile(_path) => Some( "Index file not found in directory. Check if index.html or other index files exist".to_string() ), + ErrorKind::RateLimit(e) => Some(format!( + "Rate limiting error: {e}. Consider adjusting rate limiting configuration or waiting before retrying" + )), } } @@ -468,6 +475,7 @@ impl Hash for ErrorKind { Self::BasicAuthExtractorError(e) => e.to_string().hash(state), Self::Cookies(e) => e.to_string().hash(state), Self::StatusCodeSelectorError(e) => e.to_string().hash(state), + Self::RateLimit(e) => e.to_string().hash(state), } } }