Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: make AtomicWaker panic safe #3689

Merged
merged 15 commits into from
Nov 2, 2021
Merged
31 changes: 31 additions & 0 deletions tokio/src/sync/task/atomic_waker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{self, AtomicUsize};

use std::fmt;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
use std::task::Waker;

Expand All @@ -27,6 +28,9 @@ pub(crate) struct AtomicWaker {
waker: UnsafeCell<Option<Waker>>,
}

impl RefUnwindSafe for AtomicWaker {}
impl UnwindSafe for AtomicWaker {}

// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell
// stores a `Waker` value produced by calls to `register` and many threads can
// race to take the waker by calling `wake.
Expand Down Expand Up @@ -178,8 +182,16 @@ impl AtomicWaker {
{
WAITING => {
unsafe {
// If `into_waker` panics (because it's code outside of
// AtomicWaker) we need to prime a guard that is called on
// unwind restore the waker to a WAITING state. Otherwise
// any future calls to register will incorrectly be stuck in
// a state where it believes it's being updated by someone
// else.
let guard = PanicGuard(&self.state);
// Locked acquired, update the waker cell
self.waker.with_mut(|t| *t = Some(waker.into_waker()));
std::mem::forget(guard);
udoprog marked this conversation as resolved.
Show resolved Hide resolved

// Release the lock. If the state transitioned to include
// the `WAKING` bit, this means that a wake has been
Expand Down Expand Up @@ -211,6 +223,8 @@ impl AtomicWaker {

// The atomic swap was complete, now
// wake the waker and return.
//
// If this panics, we end up in a consumed state.
waker.wake();
}
}
Expand All @@ -220,6 +234,9 @@ impl AtomicWaker {
// Currently in the process of waking the task, i.e.,
// `wake` is currently being called on the old waker.
// So, we call wake on the new waker.
//
// If this panics, someone else is responsible for restoring the
// state of the waker.
waker.wake();

// This is equivalent to a spin lock, so use a spin hint.
Expand All @@ -238,13 +255,27 @@ impl AtomicWaker {
debug_assert!(state == REGISTERING || state == REGISTERING | WAKING);
}
}

struct PanicGuard<'a>(&'a AtomicUsize);

impl Drop for PanicGuard<'_> {
fn drop(&mut self) {
// On panics, we restore straight back to WAITING. We just tried
// to replace the atomic waker, so it's a legitimate outcome
// that the net effect of the operation is a no-op (i.e. no
// wakeup was issued).
self.0.swap(WAITING, AcqRel);
}
}
udoprog marked this conversation as resolved.
Show resolved Hide resolved
}

/// Wakes the task that last called `register`.
///
/// If `register` has not been called yet, then this does nothing.
pub(crate) fn wake(&self) {
if let Some(waker) = self.take_waker() {
// If wake panics, we've consumed the waker which is a legitimate
// outcome.
waker.wake();
}
}
Expand Down
39 changes: 39 additions & 0 deletions tokio/src/sync/tests/atomic_waker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,42 @@ fn wake_without_register() {

assert!(!waker.is_woken());
}

#[test]
fn atomic_waker_panic_safe() {
use std::panic;
use std::ptr;
use std::task::{RawWaker, RawWakerVTable, Waker};

static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new(
|_| panic!("clone"),
|_| unimplemented!("wake"),
|_| unimplemented!("wake_by_ref"),
|_| (),
);

static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new(
|_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE),
|_| unimplemented!("wake"),
|_| unimplemented!("wake_by_ref"),
|_| (),
);

let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };
let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) };

let atomic_waker = AtomicWaker::new();

let panicking = panic::AssertUnwindSafe(&panicking);

let result = panic::catch_unwind(|| {
let panic::AssertUnwindSafe(panicking) = panicking;
atomic_waker.register_by_ref(panicking);
});

assert!(result.is_err());
assert!(atomic_waker.take_waker().is_none());

atomic_waker.register_by_ref(&nonpanicking);
assert!(atomic_waker.take_waker().is_some());
}