Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 81 additions & 56 deletions mea/src/internal/semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Semaphore {
}
}

/// Decrease a semaphore's permits by a maximum of `n`.
/// Decrease the semaphore's permits by a maximum of `n`.
///
/// Return the number of permits that were actually reduced.
pub(crate) fn forget(&self, n: usize) -> usize {
Expand All @@ -95,6 +95,14 @@ impl Semaphore {
}
}

/// Decrease the semaphore's permits by `n`.
///
/// If the semaphore has not enough permits, enqueue front an empty waiter to consume the
/// permits.
pub(crate) fn forget_exact(&self, n: usize) {
acquired_or_enqueue(self, n, &mut None, None, false);
}

/// Acquires `n` permits from the semaphore.
pub(crate) async fn acquire(&self, n: usize) {
let fut = Acquire {
Expand Down Expand Up @@ -126,8 +134,6 @@ impl Semaphore {
Some(waiter) => {
if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
} else {
unreachable!("waker was removed from the list without a waker");
}
}
}
Expand Down Expand Up @@ -165,8 +171,6 @@ impl Semaphore {
Some(waiter) => {
if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
} else {
unreachable!("waker was removed from the list without a waker");
}
}
}
Expand Down Expand Up @@ -261,60 +265,81 @@ impl Future for Acquire<'_> {
// not yet enqueued
let needed = *permits;

let mut acquired = 0;
let mut current = semaphore.permits.load(Ordering::Acquire);
let mut lock = None;

let mut waiters = loop {
let mut remaining = 0;
let total = current.checked_add(acquired).expect("permits overflow");
let (next, acq) = if total >= needed {
let next = current - (needed - acquired);
(next, needed - acquired)
} else {
remaining = (needed - acquired) - current;
(0, current)
};

if remaining > 0 && lock.is_none() {
// No permits were immediately available, so this permit will
// (probably) need to wait. We'll need to acquire a lock on the
// wait queue before continuing. We need to do this _before_ the
// CAS that sets the new value of the semaphore's `permits`
// counter. Otherwise, if we subtract the permits and then
// acquire the lock, we might miss additional permits being
// added while waiting for the lock.
lock = Some(semaphore.waiters.lock());
}

match semaphore.permits.compare_exchange(
current,
next,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
acquired += acq;
if remaining == 0 {
*done = true;
return Poll::Ready(());
}
break lock.expect("lock not acquired");
}
Err(actual) => current = actual,
}
};

waiters.register_waiter(index, |node| match node {
None => Some(WaitNode {
permits: needed - acquired,
waker: Some(cx.waker().clone()),
}),
Some(node) => unreachable!("unexpected node: {:?}", node),
});
if acquired_or_enqueue(semaphore, needed, index, Some(cx.waker()), true) {
*done = true;
return Poll::Ready(());
}
}
};

Poll::Pending
}
}

/// Returns `true` if successfully acquired the semaphore; `false` otherwise.
fn acquired_or_enqueue(
semaphore: &Semaphore,
needed: usize,
idx: &mut Option<usize>,
waker: Option<&Waker>,
enqueue_last: bool,
) -> bool {
let mut acquired = 0;
let mut current = semaphore.permits.load(Ordering::Acquire);
let mut lock = None;

let mut waiters = loop {
let mut remaining = 0;
let total = current.checked_add(acquired).expect("permits overflow");
let (next, acq) = if total >= needed {
let next = current - (needed - acquired);
(next, needed - acquired)
} else {
remaining = (needed - acquired) - current;
(0, current)
};

if remaining > 0 && lock.is_none() {
// No permits were immediately available, so this permit will
// (probably) need to wait. We'll need to acquire a lock on the
// wait queue before continuing. We need to do this _before_ the
// CAS that sets the new value of the semaphore's `permits`
// counter. Otherwise, if we subtract the permits and then
// acquire the lock, we might miss additional permits being
// added while waiting for the lock.
lock = Some(semaphore.waiters.lock());
}

match semaphore
.permits
.compare_exchange(current, next, Ordering::AcqRel, Ordering::Acquire)
{
Ok(_) => {
acquired += acq;
if remaining == 0 {
return true;
}
// SAFETY: remaining > 0, lock must be Some
break lock.unwrap();
}
Err(actual) => current = actual,
}
};

if enqueue_last {
waiters.register_waiter_to_tail(idx, || {
Some(WaitNode {
permits: needed - acquired,
waker: waker.cloned(),
})
});
} else {
waiters.register_waiter_to_head(idx, || {
Some(WaitNode {
permits: needed - acquired,
waker: waker.cloned(),
})
});
}
false
}
70 changes: 46 additions & 24 deletions mea/src/internal/waitlist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,54 @@ impl<T> WaitList<T> {
Self { guard, nodes }
}

/// Registers a waiter to the head of the wait list.
///
/// # Panic
///
/// Panics if `idx` is `Some`.
pub(crate) fn register_waiter_to_head(
&mut self,
idx: &mut Option<usize>,
f: impl FnOnce() -> Option<T>,
) {
assert!(idx.is_none());

let stat = f();
let prev_head = self.nodes[self.guard].next;
let new_node = Node {
prev: self.guard,
next: prev_head,
stat,
};
let new_key = self.nodes.insert(new_node);
self.nodes[self.guard].next = new_key;
self.nodes[prev_head].prev = new_key;
*idx = Some(new_key);
}

/// Registers a waiter to the tail of the wait list.
pub(crate) fn register_waiter(
///
/// # Panic
///
/// Panics if `idx` is `Some`.
pub(crate) fn register_waiter_to_tail(
&mut self,
idx: &mut Option<usize>,
f: impl FnOnce(Option<&T>) -> Option<T>,
f: impl FnOnce() -> Option<T>,
) {
match *idx {
None => {
let stat = f(None);
let prev_tail = self.nodes[self.guard].prev;
let new_node = Node {
prev: prev_tail,
next: self.guard,
stat,
};
let new_key = self.nodes.insert(new_node);
self.nodes[self.guard].prev = new_key;
self.nodes[prev_tail].next = new_key;
*idx = Some(new_key);
}
Some(key) => {
debug_assert_ne!(key, self.guard);
if let Some(stat) = f(self.nodes[key].stat.as_ref()) {
self.nodes[key].stat = Some(stat);
}
}
}
assert!(idx.is_none());

let stat = f();
let prev_tail = self.nodes[self.guard].prev;
let new_node = Node {
prev: prev_tail,
next: self.guard,
stat,
};
let new_key = self.nodes.insert(new_node);
self.nodes[self.guard].prev = new_key;
self.nodes[prev_tail].next = new_key;
*idx = Some(new_key);
}

/// Removes a previously registered waker from the wait list.
Expand All @@ -79,7 +100,8 @@ impl<T> WaitList<T> {
idx: usize,
f: impl FnOnce(&mut T) -> bool,
) -> Option<&mut T> {
debug_assert_ne!(idx, self.guard);
assert_ne!(idx, self.guard);

// SAFETY: `idx` is a valid key + non-guard node always has `Some(stat)`
fn retrieve_stat<T>(node: &mut Node<T>) -> &mut T {
node.stat.as_mut().unwrap()
Expand Down
27 changes: 27 additions & 0 deletions mea/src/semaphore/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,33 @@ impl Semaphore {
self.s.forget(n)
}

/// Reduces the semaphore's permits by exactly `n`.
///
/// If the semaphore has not enough permits, this would enqueue front an empty waiter to
/// consume the permits, which ensures the permits are reduced by exactly `n`.
///
/// This is useful when you want to permanently remove permits from the semaphore.
///
/// # Examples
///
/// ```
/// use mea::semaphore::Semaphore;
///
/// let sem = Semaphore::new(5);
/// sem.forget_exact(3); // Removes 3 permits
/// assert_eq!(sem.available_permits(), 2);
///
/// // Trying to forget more permits than available
/// sem.forget_exact(3); // Only removes remaining 2 permits
/// assert_eq!(sem.available_permits(), 0);
///
/// sem.release(5); // Adds 5 permits
/// assert_eq!(sem.available_permits(), 4); // Only 4 permits are available
/// ```
pub fn forget_exact(&self, n: usize) {
self.s.forget_exact(n);
}

/// Adds `n` new permits to the semaphore.
///
/// # Panics
Expand Down
26 changes: 26 additions & 0 deletions mea/src/semaphore/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::sync::Arc;
use std::vec::Vec;

use super::*;
use crate::latch::Latch;

#[test]
fn no_permits() {
Expand Down Expand Up @@ -135,3 +136,28 @@ fn try_acquire_concurrently() {
drop(p1);
assert_eq!(s.available_permits(), 1);
}

#[tokio::test]
async fn acquire_then_forget_exact() {
let s = Arc::new(Semaphore::new(5));
s.forget_exact(3);
assert_eq!(s.available_permits(), 2);

let acquired = Arc::new(Latch::new(1));

let acquired_clone = acquired.clone();
let s_clone = s.clone();
tokio::spawn(async move {
let _p = s_clone.acquire(3).await;
acquired_clone.count_down();
});
assert!(acquired.try_wait().is_err());

s.forget_exact(2);
s.release(2);
assert!(acquired.try_wait().is_err());

s.release(1);
acquired.wait().await;
assert_eq!(s.available_permits(), 3);
}