Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 21 additions & 78 deletions tokio/src/runtime/time/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
//!
//! Each timer has a state field associated with it. This field contains either
//! the current scheduled time, or a special flag value indicating its state.
//! This state can either indicate that the timer is on the 'pending' queue (and
//! thus will be fired with an `Ok(())` result soon) or that it has already been
//! fired/deregistered.
//! This state can either indicate that the timer is firing (and thus will be fired
//! with an `Ok(())` result soon) or that it has already been fired/deregistered.
//!
//! This single state field allows for code that is firing the timer to
//! synchronize with any racing `reset` calls reliably.
Expand All @@ -49,10 +48,10 @@
//! There is of course a race condition between timer reset and timer
//! expiration. If the driver fails to observe the updated expiration time, it
//! could trigger expiration of the timer too early. However, because
//! [`mark_pending`][mark_pending] performs a compare-and-swap, it will identify this race and
//! refuse to mark the timer as pending.
//! [`mark_firing`][mark_firing] performs a compare-and-swap, it will identify this race and
//! refuse to mark the timer as firing.
//!
//! [mark_pending]: TimerHandle::mark_pending
//! [mark_firing]: TimerHandle::mark_firing

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicU64;
Expand All @@ -70,9 +69,9 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull};

type TimerResult = Result<(), crate::time::error::Error>;

const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;
pub(super) const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_FIRING: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_FIRING;
/// The largest safe integer to use for ticks.
///
/// This value should be updated if any other signal values are added above.
Expand Down Expand Up @@ -123,10 +122,6 @@ impl StateCell {
}
}

fn is_pending(&self) -> bool {
self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE
}

/// Returns the current expiration time, or None if not currently scheduled.
fn when(&self) -> Option<u64> {
let cur_state = self.state.load(Ordering::Relaxed);
Expand Down Expand Up @@ -162,26 +157,28 @@ impl StateCell {
}
}

/// Marks this timer as being moved to the pending list, if its scheduled
/// time is not after `not_after`.
/// Marks this timer firing, if its scheduled time is not after `not_after`.
///
/// If the timer is scheduled for a time after `not_after`, returns an Err
/// containing the current scheduled time.
///
/// SAFETY: Must hold the driver lock.
unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> {
// Quick initial debug check to see if the timer is already fired. Since
// firing the timer can only happen with the driver lock held, we know
// we shouldn't be able to "miss" a transition to a fired state, even
// with relaxed ordering.
let mut cur_state = self.state.load(Ordering::Relaxed);

loop {
// Because its state is STATE_DEREGISTERED, it has been fired.
if cur_state == STATE_DEREGISTERED {
break Err(cur_state);
}
// improve the error message for things like
// https://github.com/tokio-rs/tokio/issues/3675
assert!(
cur_state < STATE_MIN_VALUE,
"mark_pending called when the timer entry is in an invalid state"
"mark_firing called when the timer entry is in an invalid state"
);

if cur_state > not_after {
Expand All @@ -190,7 +187,7 @@ impl StateCell {

match self.state.compare_exchange_weak(
cur_state,
STATE_PENDING_FIRE,
STATE_FIRING,
Ordering::AcqRel,
Ordering::Acquire,
) {
Expand Down Expand Up @@ -337,11 +334,6 @@ pub(crate) struct TimerShared {
/// Only accessed under the entry lock.
pointers: linked_list::Pointers<TimerShared>,

/// The expiration time for which this entry is currently registered.
/// Generally owned by the driver, but is accessed by the entry when not
/// registered.
cached_when: AtomicU64,

/// Current state. This records whether the timer entry is currently under
/// the ownership of the driver, and if not, its current state (not
/// complete, fired, error, etc).
Expand All @@ -356,7 +348,6 @@ unsafe impl Sync for TimerShared {}
impl std::fmt::Debug for TimerShared {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TimerShared")
.field("cached_when", &self.cached_when.load(Ordering::Relaxed))
.field("state", &self.state)
.finish()
}
Expand All @@ -374,40 +365,12 @@ impl TimerShared {
pub(super) fn new(shard_id: u32) -> Self {
Self {
shard_id,
cached_when: AtomicU64::new(0),
pointers: linked_list::Pointers::new(),
state: StateCell::default(),
_p: PhantomPinned,
}
}

/// Gets the cached time-of-expiration value.
pub(super) fn cached_when(&self) -> u64 {
// Cached-when is only accessed under the driver lock, so we can use relaxed
self.cached_when.load(Ordering::Relaxed)
}

/// Gets the true time-of-expiration value, and copies it into the cached
/// time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
pub(super) unsafe fn sync_when(&self) -> u64 {
let true_when = self.true_when();

self.cached_when.store(true_when, Ordering::Relaxed);

true_when
}

/// Sets the cached time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
unsafe fn set_cached_when(&self, when: u64) {
self.cached_when.store(when, Ordering::Relaxed);
}

/// Returns the true time-of-expiration value, with relaxed memory ordering.
pub(super) fn true_when(&self) -> u64 {
self.state.when().expect("Timer already fired")
Expand All @@ -420,7 +383,6 @@ impl TimerShared {
/// in the timer wheel.
pub(super) unsafe fn set_expiration(&self, t: u64) {
self.state.set_expiration(t);
self.cached_when.store(t, Ordering::Relaxed);
}

/// Sets the true time-of-expiration only if it is after the current.
Expand Down Expand Up @@ -590,16 +552,8 @@ impl TimerEntry {
}

impl TimerHandle {
pub(super) unsafe fn cached_when(&self) -> u64 {
unsafe { self.inner.as_ref().cached_when() }
}

pub(super) unsafe fn sync_when(&self) -> u64 {
unsafe { self.inner.as_ref().sync_when() }
}

pub(super) unsafe fn is_pending(&self) -> bool {
unsafe { self.inner.as_ref().state.is_pending() }
pub(super) unsafe fn true_when(&self) -> u64 {
unsafe { self.inner.as_ref().true_when() }
}

/// Forcibly sets the true and cached expiration times to the given tick.
Expand All @@ -610,27 +564,16 @@ impl TimerHandle {
self.inner.as_ref().set_expiration(tick);
}

/// Attempts to mark this entry as pending. If the expiration time is after
/// Attempts to mark this entry as firing. If the expiration time is after
/// `not_after`, however, returns an Err with the current expiration time.
///
/// If an `Err` is returned, the `cached_when` value will be updated to this
/// new expiration time.
///
/// SAFETY: The caller must ensure that the handle remains valid, the driver
/// lock is held, and that the timer is not in any wheel linked lists.
/// After returning Ok, the entry must be added to the pending list.
pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
match self.inner.as_ref().state.mark_pending(not_after) {
Ok(()) => {
// mark this as being on the pending queue in cached_when
self.inner.as_ref().set_cached_when(u64::MAX);
Ok(())
}
Err(tick) => {
self.inner.as_ref().set_cached_when(tick);
Err(tick)
}
}
pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> {
self.inner.as_ref().state.mark_firing(not_after)
}

/// Attempts to transition to a terminal state. If the state is already a
Expand Down
60 changes: 45 additions & 15 deletions tokio/src/runtime/time/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

mod entry;
pub(crate) use entry::TimerEntry;
use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION};
use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION, STATE_DEREGISTERED};

mod handle;
pub(crate) use self::handle::Handle;
Expand Down Expand Up @@ -319,23 +319,53 @@ impl Handle {
now = lock.elapsed();
}

while let Some(entry) = lock.poll(now) {
debug_assert!(unsafe { entry.is_pending() });

// SAFETY: We hold the driver lock, and just removed the entry from any linked lists.
if let Some(waker) = unsafe { entry.fire(Ok(())) } {
waker_list.push(waker);

if !waker_list.can_push() {
// Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped.
drop(lock);

waker_list.wake_all();

lock = self.inner.lock_sharded_wheel(id);
while let Some(expiration) = lock.poll(now) {
lock.set_elapsed(expiration.deadline);
// It is critical for `GuardedLinkedList` safety that the guard node is
// pinned in memory and is not dropped until the guarded list is dropped.
let guard = TimerShared::new(id);
pin!(guard);
let guard_handle = guard.as_ref().get_ref().handle();

// * This list will be still guarded by the lock of the Wheel with the specefied id.
// `EntryWaitersList` wrapper makes sure we hold the lock to modify it.
// * This wrapper will empty the list on drop. It is critical for safety
// that we will not leave any list entry with a pointer to the local
// guard node after this function returns / panics.
// Safety: The `TimerShared` inside this `TimerHandle` is pinned in the memory.
let mut list = unsafe { lock.get_waiters_list(&expiration, guard_handle, id, self) };

while let Some(entry) = list.pop_back_locked(&mut lock) {
let deadline = expiration.deadline;
// Try to expire the entry; this is cheap (doesn't synchronize) if
// the timer is not expired, and updates cached_when.
match unsafe { entry.mark_firing(deadline) } {
Ok(()) => {
// Entry was expired.
// SAFETY: We hold the driver lock, and just removed the entry from any linked lists.
if let Some(waker) = unsafe { entry.fire(Ok(())) } {
waker_list.push(waker);

if !waker_list.can_push() {
// Wake a batch of wakers. To avoid deadlock,
// we must do this with the lock temporarily dropped.
drop(lock);
waker_list.wake_all();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm worried about panics here. What happens to entries still in the guarded list?

I think some other examples of this wrap the guarded list in something that removes all entries from the list without waking their wakers.

Copy link
Copy Markdown
Contributor Author

@wathenjiang wathenjiang Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We shoud ensure the list is empty before dropping it and no double panic, even if it occurs panic. I will wrap it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, we use EntryWaitersList instead. It will address our purpose of memory safety.


lock = self.inner.lock_sharded_wheel(id);
}
}
}
Err(state) if state == STATE_DEREGISTERED => {}
Err(state) => {
// Safety: This Entry has not expired.
unsafe { lock.reinsert_entry(entry, deadline, state) };
}
}
}
lock.occupied_bit_maintain(&expiration);
}

let next_wake_up = lock.poll_at();
drop(lock);

Expand Down
22 changes: 12 additions & 10 deletions tokio/src/runtime/time/wheel/level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pub(crate) struct Level {
}

/// Indicates when a slot must be processed next.
#[derive(Debug)]
pub(crate) struct Expiration {
/// The level containing the slot.
pub(crate) level: usize,
Expand Down Expand Up @@ -81,7 +80,7 @@ impl Level {
// pseudo-ring buffer, and we rotate around them indefinitely. If we
// compute a deadline before now, and it's the top level, it
// therefore means we're actually looking at a slot in the future.
debug_assert_eq!(self.level, super::NUM_LEVELS - 1);
debug_assert_eq!(self.level, super::MAX_LEVEL_INDEX);

deadline += level_range;
}
Expand Down Expand Up @@ -120,31 +119,34 @@ impl Level {
}

pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) {
let slot = slot_for(item.cached_when(), self.level);
let slot = slot_for(item.true_when(), self.level);

self.slot[slot].push_front(item);

self.occupied |= occupied_bit(slot);
}

pub(crate) unsafe fn remove_entry(&mut self, item: NonNull<TimerShared>) {
let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level);
let slot = slot_for(unsafe { item.as_ref().true_when() }, self.level);

unsafe { self.slot[slot].remove(item) };
if self.slot[slot].is_empty() {
// The bit is currently set
debug_assert!(self.occupied & occupied_bit(slot) != 0);

// Unset the bit
self.occupied ^= occupied_bit(slot);
}
}

pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList {
self.occupied &= !occupied_bit(slot);

pub(super) fn take_slot(&mut self, slot: usize) -> EntryList {
std::mem::take(&mut self.slot[slot])
}

pub(super) fn occupied_bit_maintain(&mut self, slot: usize) {
if self.slot[slot].is_empty() {
self.occupied &= !occupied_bit(slot);
} else {
self.occupied |= occupied_bit(slot);
}
}
}

impl fmt::Debug for Level {
Expand Down
Loading