diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index aa23dea7d3c..d7eb1d6b77e 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -368,6 +368,31 @@ impl Semaphore { assert_eq!(rem, 0); } + /// Decrease a semaphore's permits by a maximum of `n`. + /// + /// If there are insufficient permits and it's not possible to reduce by `n`, + /// return the number of permits that were actually reduced. + pub(crate) fn forget_permits(&self, n: usize) -> usize { + if n == 0 { + return 0; + } + + let mut curr_bits = self.permits.load(Acquire); + loop { + let curr = curr_bits >> Self::PERMIT_SHIFT; + let new = curr.saturating_sub(n); + match self.permits.compare_exchange_weak( + curr_bits, + new << Self::PERMIT_SHIFT, + AcqRel, + Acquire, + ) { + Ok(_) => return std::cmp::min(curr, n), + Err(actual) => curr_bits = actual, + }; + } + } + fn poll_acquire( &self, cx: &mut Context<'_>, diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index 25e4134373c..d0ee12591ee 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -481,6 +481,14 @@ impl Semaphore { self.ll_sem.release(n); } + /// Decrease a semaphore's permits by a maximum of `n`. + /// + /// If there are insufficient permits and it's not possible to reduce by `n`, + /// return the number of permits that were actually reduced. + pub fn forget_permits(&self, n: usize) -> usize { + self.ll_sem.forget_permits(n) + } + /// Acquires a permit from the semaphore. /// /// If the semaphore has been closed, this returns an [`AcquireError`]. diff --git a/tokio/src/sync/tests/loom_semaphore_batch.rs b/tokio/src/sync/tests/loom_semaphore_batch.rs index 76a1bc00626..85cd584d2d4 100644 --- a/tokio/src/sync/tests/loom_semaphore_batch.rs +++ b/tokio/src/sync/tests/loom_semaphore_batch.rs @@ -213,3 +213,31 @@ fn release_during_acquire() { assert_eq!(10, semaphore.available_permits()); }) } + +#[test] +fn concurrent_permit_updates() { + loom::model(move || { + let semaphore = Arc::new(Semaphore::new(5)); + let t1 = { + let semaphore = semaphore.clone(); + thread::spawn(move || semaphore.release(3)) + }; + let t2 = { + let semaphore = semaphore.clone(); + thread::spawn(move || { + semaphore + .try_acquire(1) + .expect("try_acquire should succeed") + }) + }; + let t3 = { + let semaphore = semaphore.clone(); + thread::spawn(move || semaphore.forget_permits(2)) + }; + + t1.join().unwrap(); + t2.join().unwrap(); + t3.join().unwrap(); + assert_eq!(semaphore.available_permits(), 5); + }) +} diff --git a/tokio/src/sync/tests/semaphore_batch.rs b/tokio/src/sync/tests/semaphore_batch.rs index 391797b3f66..09610ce71f2 100644 --- a/tokio/src/sync/tests/semaphore_batch.rs +++ b/tokio/src/sync/tests/semaphore_batch.rs @@ -287,3 +287,24 @@ fn release_permits_at_drop() { assert!(fut.as_mut().poll(&mut cx).is_pending()); } } + +#[test] +fn forget_permits_basic() { + let s = Semaphore::new(10); + assert_eq!(s.forget_permits(4), 4); + assert_eq!(s.available_permits(), 6); + assert_eq!(s.forget_permits(10), 6); + assert_eq!(s.available_permits(), 0); +} + +#[test] +fn update_permits_many_times() { + let s = Semaphore::new(5); + let mut acquire = task::spawn(s.acquire(7)); + assert_pending!(acquire.poll()); + s.release(5); + assert_ready_ok!(acquire.poll()); + assert_eq!(s.available_permits(), 3); + assert_eq!(s.forget_permits(3), 3); + assert_eq!(s.available_permits(), 0); +}