Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions lychee-bin/tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,36 @@ The config file should contain every possible key for documentation purposes."
Ok(())
}

/// Regression test: checking the same URL many times must hit the server once.
///
/// The input consists of `count` copies of the mock server's URL. The mock is
/// mounted with `.expect(1)`, and the `--host-stats` output is required to
/// report 100% success with all but the first lookup served from the cache.
#[tokio::test]
async fn test_no_duplicate_requests() {
    let server = wiremock::MockServer::start().await;
    let count = 100; // given 100 duplicate URLs
    let cached = "99.0%"; // we expect 99 out of 100 to be cached

    wiremock::Mock::given(wiremock::matchers::method("GET"))
        // Simulate real-world latency.
        // Keep the delay to prove how we make use of synchronization
        // primitives to prevent duplicate requests: while the first request
        // is still in flight, the remaining duplicates must wait for its
        // result instead of being sent as well.
        //
        // The delay is configured on the response template rather than by
        // sleeping inside a responder closure, so that the thread driving
        // the mock server is not blocked while the delay elapses.
        .respond_with(ResponseTemplate::new(200).set_delay(std::time::Duration::from_secs(1)))
        // Exactly one request may reach the server.
        .expect(1)
        .mount(&server)
        .await;

    cargo_bin_cmd!()
        // Feed the same URL `count` times, separated by spaces, via stdin.
        .write_stdin(format!("{} ", server.uri()).repeat(count))
        .arg("-")
        .arg("--host-stats")
        // the request interval must not have an effect on duplicates
        .arg("--host-request-interval=1s")
        .assert()
        .success()
        .stdout(contains("100.0% success"))
        .stdout(contains(format!("{cached} cached")));
}

#[tokio::test]
async fn test_process_internal_host_caching() -> Result<()> {
// Note that this process-internal per-host caching
Expand Down
34 changes: 25 additions & 9 deletions lychee-lib/src/ratelimit/host/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ use http::StatusCode;
use humantime_serde::re::humantime::format_duration;
use log::warn;
use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
use std::time::{Duration, Instant};
use std::{num::NonZeroU32, sync::Mutex};
use std::{
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Semaphore;

use super::key::HostKey;
Expand Down Expand Up @@ -63,6 +66,9 @@ pub struct Host {
/// Per-host cache to prevent duplicate requests during a single link check invocation.
/// Note that this cache has no direct relation to the inter-process persistable [`crate::CacheStatus`].
cache: HostCache,

/// Keep track of currently active requests, to prevent duplicate concurrent requests
active_requests: DashMap<Uri, Arc<tokio::sync::Mutex<()>>>,
}

impl Host {
Expand Down Expand Up @@ -91,6 +97,7 @@ impl Host {
stats: Mutex::new(HostStats::default()),
backoff_duration: Mutex::new(Duration::from_millis(0)),
cache: DashMap::new(),
active_requests: DashMap::new(),
}
}

Expand Down Expand Up @@ -142,26 +149,22 @@ impl Host {
let mut url = request.url().clone();
url.set_fragment(None);
let uri = Uri::from(url);

let _permit = self.acquire_semaphore().await;
let _uri_guard = self.lock_uri_mutex(uri.clone()).await;

if let Some(cached) = self.get_cached_status(&uri, needs_body) {
self.record_cache_hit();
return Ok(cached);
}

self.record_cache_miss();
let _permit = self.acquire_semaphore().await;

self.await_backoff().await;

if let Some(rate_limiter) = &self.rate_limiter {
rate_limiter.until_ready().await;
}

if let Some(cached) = self.get_cached_status(&uri, needs_body) {
self.record_cache_hit();
return Ok(cached);
}

self.record_cache_miss();
self.perform_request(request, uri, needs_body).await
}

Expand Down Expand Up @@ -209,6 +212,19 @@ impl Host {
}
}

/// Get a [`tokio::sync::OwnedMutexGuard<()>`]
/// to prevent concurrent requests to identical [`Uri`]s.
async fn lock_uri_mutex(&self, uri: Uri) -> tokio::sync::OwnedMutexGuard<()> {
let uri_mutex = self
.active_requests
.entry(uri)
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone();

uri_mutex.lock_owned().await
}

/// Enforce the maximum concurrency of this host
async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
self.semaphore
.acquire()
Expand Down