sync: make AtomicWaker panic safe #3689

Merged
merged 15 commits into from
Nov 2, 2021
78 changes: 69 additions & 9 deletions tokio/src/sync/task/atomic_waker.rs
@@ -4,6 +4,7 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{self, AtomicUsize};

use std::fmt;
use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe};
use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
use std::task::Waker;

@@ -27,6 +28,9 @@ pub(crate) struct AtomicWaker {
waker: UnsafeCell<Option<Waker>>,
}

impl RefUnwindSafe for AtomicWaker {}
impl UnwindSafe for AtomicWaker {}

// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell
// stores a `Waker` value produced by calls to `register` and many threads can
// race to take the waker by calling `wake`.
@@ -84,7 +88,7 @@ pub(crate) struct AtomicWaker {
// back to `WAITING`. This transition must succeed as, at this point, the state
// cannot be transitioned by another thread.
//
// If the thread is unable to obtain the lock, the `WAKING` bit is still.
// If the thread is unable to obtain the lock, the `WAKING` bit is still set.
// This is because it has either been set by the current thread but the previous
// value included the `REGISTERING` bit **or** a concurrent thread is in the
// `WAKING` critical section. Either way, no action must be taken.
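
The locking protocol described in the comment above can be sketched with plain `std` atomics. This is a simplified illustration and not tokio's code: the cell holds a `u32` instead of a `Waker`, a `Mutex` stands in for the `UnsafeCell`, the names `Cell`, `try_register`, and `take` are hypothetical, and the bit values of the constants are assumed.

```rust
use std::sync::atomic::{AtomicUsize, Ordering::{AcqRel, Acquire, Release}};
use std::sync::Mutex;

// Assumed bit layout, for illustration only.
const WAITING: usize = 0;
const REGISTERING: usize = 0b01;
const WAKING: usize = 0b10;

struct Cell {
    state: AtomicUsize,
    // A Mutex replaces the `UnsafeCell` so the sketch needs no `unsafe`.
    value: Mutex<Option<u32>>,
}

impl Cell {
    fn new() -> Self {
        Cell { state: AtomicUsize::new(WAITING), value: Mutex::new(None) }
    }

    /// Returns `true` if `v` was stored and is still in the cell afterwards.
    fn try_register(&self, v: u32) -> bool {
        // Take the REGISTERING lock; fail if any bit is already set.
        if self.state.compare_exchange(WAITING, REGISTERING, Acquire, Acquire).is_err() {
            return false;
        }
        *self.value.lock().unwrap() = Some(v);
        // Release the lock. This fails only if a concurrent `take` set WAKING.
        match self.state.compare_exchange(REGISTERING, WAITING, AcqRel, Acquire) {
            Ok(_) => true,
            Err(actual) => {
                debug_assert_eq!(actual, REGISTERING | WAKING);
                // The real code would wake the value it takes out here.
                self.value.lock().unwrap().take();
                self.state.swap(WAITING, AcqRel);
                false
            }
        }
    }

    /// Takes the stored value if no other thread is registering or waking.
    fn take(&self) -> Option<u32> {
        match self.state.fetch_or(WAKING, AcqRel) {
            WAITING => {
                let v = self.value.lock().unwrap().take();
                // Clear the WAKING bit to release the lock.
                self.state.fetch_and(!WAKING, Release);
                v
            }
            // Another thread holds a lock; the WAKING bit stays set for it to see.
            _ => None,
        }
    }
}

fn main() {
    let cell = Cell::new();
    assert!(cell.try_register(7));
    assert_eq!(cell.take(), Some(7));
    assert_eq!(cell.take(), None);
}
```

The point of the sketch is the failure mode this PR addresses: if code running between the two `compare_exchange` calls in `try_register` unwinds, the state stays at `REGISTERING` and every later registration fails.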
@@ -171,15 +175,35 @@ impl AtomicWaker {
where
W: WakerRef,
{
fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> {
std::panic::catch_unwind(AssertUnwindSafe(f))
Contributor Author

I think `Waker`s are `UnwindSafe`, and they are the only things captured in the closures below, so it might be possible to add that as a bound (`F: UnwindSafe + FnOnce() -> R`) instead of using `AssertUnwindSafe`?

Contributor

Ah, it might have been a previous version which was unhappy. I can try changing it to use catch_unwind directly.
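
The two options in this thread differ in who vouches for unwind safety. A small sketch of both, using hypothetical helper names `catch_asserted` and `catch_bounded` (neither is part of the PR):

```rust
use std::panic::{self, AssertUnwindSafe, UnwindSafe};
use std::thread;

// Option 1: wrap the closure, asserting unwind safety on the caller's behalf.
fn catch_asserted<F: FnOnce() -> R, R>(f: F) -> thread::Result<R> {
    panic::catch_unwind(AssertUnwindSafe(f))
}

// Option 2: require the closure itself to be `UnwindSafe`.
fn catch_bounded<F: FnOnce() -> R + UnwindSafe, R>(f: F) -> thread::Result<R> {
    panic::catch_unwind(f)
}

fn main() {
    let mut counter = 0;
    // This closure captures `&mut i32`, which is not `UnwindSafe`,
    // so only the asserting helper accepts it.
    let r = catch_asserted(|| {
        counter += 1;
        counter
    });
    assert_eq!(r.ok(), Some(1));

    // A closure with no captures satisfies the `UnwindSafe` bound.
    let r = catch_bounded(|| -> i32 { panic!("boom") });
    assert!(r.is_err());
}
```

With the bound, the compiler rejects closures that capture something not known to be unwind safe, which is stricter but pushes the requirement onto every generic caller.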

}

match self
.state
.compare_exchange(WAITING, REGISTERING, Acquire, Acquire)
.unwrap_or_else(|x| x)
{
WAITING => {
unsafe {
// Lock acquired, update the waker cell
self.waker.with_mut(|t| *t = Some(waker.into_waker()));
// If `into_waker` panics (because it's code outside of
// AtomicWaker) we need to prime a guard that is called on
// unwind to restore the waker to a WAITING state. Otherwise
// any future calls to register will incorrectly be stuck
// believing it's being updated by someone else.
let new_waker_or_panic = catch_unwind(move || waker.into_waker());

// Set the field to contain the new waker, or if
// `into_waker` panicked, leave the old value.
let mut maybe_panic = None;
let mut old_waker = None;
match new_waker_or_panic {
Ok(new_waker) => {
old_waker = self.waker.with_mut(|t| (*t).take());
self.waker.with_mut(|t| *t = Some(new_waker));
}
Err(panic) => maybe_panic = Some(panic),
}

// Release the lock. If the state transitioned to include
// the `WAKING` bit, this means that a wake has been
@@ -193,33 +217,67 @@ impl AtomicWaker {
.compare_exchange(REGISTERING, WAITING, AcqRel, Acquire);

match res {
Ok(_) => {}
Ok(_) => {
// We don't want to give the caller the panic if it
// was someone else who put in that waker.
Contributor Author

@udoprog udoprog Aug 26, 2021

Couldn't this end in a situation where a panic in the Drop impl of a Waker is ignored, and if so is that OK?

My general stance towards panics is that they should be propagated somewhere - otherwise errors which cause them might end up going unnoticed and contribute to other unpredictable side effects.

Contributor

Well, that's the question really. There are certainly some places we don't want panics propagating, e.g. inside the runtime. The panic is printed to the console even if we drop it.
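
To make the trade-off in this thread concrete: a panic that is caught with `catch_unwind` and then discarded never reaches the caller, but the panic hook has already run by then, and the default hook prints the message to stderr. A minimal sketch, where `NoisyDrop` is a hypothetical stand-in for a waker whose `Drop` panics and `drop_quietly` is an illustrative helper:

```rust
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

// Counts how many times the panic hook ran.
static HOOK_CALLS: AtomicUsize = AtomicUsize::new(0);

struct NoisyDrop;

impl Drop for NoisyDrop {
    fn drop(&mut self) {
        panic!("panic in Drop");
    }
}

/// Drops `value`, discarding any panic; returns whether one occurred.
fn drop_quietly<T>(value: T) -> bool {
    panic::catch_unwind(AssertUnwindSafe(move || drop(value))).is_err()
}

fn main() {
    // Replace the default (printing) hook with one that only counts calls.
    panic::set_hook(Box::new(|_| {
        HOOK_CALLS.fetch_add(1, SeqCst);
    }));

    assert!(drop_quietly(NoisyDrop)); // panic swallowed here
    assert!(!drop_quietly(42)); // ordinary drop, no panic

    // The hook still observed the swallowed panic.
    assert_eq!(HOOK_CALLS.load(SeqCst), 1);
}
```

So the panic is reported but not propagated, which is the behavior being debated above.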

let _ = catch_unwind(move || {
drop(old_waker);
});
}
Err(actual) => {
// This branch can only be reached if a
// concurrent thread called `wake`. In this
// case, `actual` **must** be `REGISTERING |
// `WAKING`.
// WAKING`.
debug_assert_eq!(actual, REGISTERING | WAKING);

// Take the waker to wake once the atomic operation has
// completed.
let waker = self.waker.with_mut(|t| (*t).take()).unwrap();
let mut waker = self.waker.with_mut(|t| (*t).take());

// Just swap, because no one could change state
// while state == `REGISTERING | WAKING`
self.state.swap(WAITING, AcqRel);

// The atomic swap was complete, now
// wake the waker and return.
waker.wake();
// If `into_waker` panicked, then the waker in the
// waker slot is actually the old waker.
if maybe_panic.is_some() {
old_waker = waker.take();
}

// We don't want to give the caller the panic if it
// was someone else who put in that waker.
Contributor Author

Same Q here.

if let Some(old_waker) = old_waker {
let _ = catch_unwind(move || {
old_waker.wake();
});
}

// The atomic swap was complete, now wake the waker
// and return.
//
// If this panics, we end up in a consumed state and
// return the panic to the caller.
if let Some(waker) = waker {
debug_assert!(maybe_panic.is_none());
waker.wake();
}
}
}

if let Some(panic) = maybe_panic {
// If `into_waker` panicked, return the panic to the caller.
resume_unwind(panic);
}
}
}
WAKING => {
// Currently in the process of waking the task, i.e.,
// `wake` is currently being called on the old waker.
// So, we call wake on the new waker.
//
// If this panics, someone else is responsible for restoring the
// state of the waker.
waker.wake();

// This is equivalent to a spin lock, so use a spin hint.
@@ -245,6 +303,8 @@ impl AtomicWaker {
/// If `register` has not been called yet, then this does nothing.
pub(crate) fn wake(&self) {
if let Some(waker) = self.take_waker() {
// If wake panics, we've consumed the waker which is a legitimate
// outcome.
waker.wake();
}
}
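
Stripped of the waker bookkeeping, the `register` change above has a catch, unlock, re-raise shape: run the foreign code under `catch_unwind`, release the lock whatever the outcome, then hand any panic back to the caller with `resume_unwind`. A minimal sketch of that shape, assuming a hypothetical `Slot` type with a boolean lock in place of the `REGISTERING` state:

```rust
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicBool, Ordering::{Acquire, Release}};

struct Slot {
    locked: AtomicBool,
}

impl Slot {
    /// Runs `f` under the lock. Returns `None` if the lock is already held.
    /// If `f` panics, the lock is released before the panic continues.
    fn with_lock<R>(&self, f: impl FnOnce() -> R) -> Option<R> {
        if self.locked.compare_exchange(false, true, Acquire, Acquire).is_err() {
            return None;
        }
        // Catch a panic from user code so the lock is not left held.
        let result = panic::catch_unwind(AssertUnwindSafe(f));
        self.locked.store(false, Release);
        match result {
            Ok(v) => Some(v),
            // Return the panic to the caller once the state is restored.
            Err(payload) => panic::resume_unwind(payload),
        }
    }
}

fn main() {
    let slot = Slot { locked: AtomicBool::new(false) };

    // The panic still reaches the caller...
    let caught = panic::catch_unwind(AssertUnwindSafe(|| {
        slot.with_lock(|| -> () { panic!("user code") })
    }));
    assert!(caught.is_err());

    // ...but the slot is usable again afterwards.
    assert_eq!(slot.with_lock(|| 5), Some(5));
}
```

A drop guard would achieve the same unlock-on-unwind effect; the explicit `catch_unwind` form is shown here because it is the form the diff uses.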
39 changes: 39 additions & 0 deletions tokio/src/sync/tests/atomic_waker.rs
@@ -32,3 +32,42 @@ fn wake_without_register() {

assert!(!waker.is_woken());
}

#[test]
fn atomic_waker_panic_safe() {
use std::panic;
use std::ptr;
use std::task::{RawWaker, RawWakerVTable, Waker};

static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new(
|_| panic!("clone"),
|_| unimplemented!("wake"),
|_| unimplemented!("wake_by_ref"),
|_| (),
);

static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new(
|_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE),
|_| unimplemented!("wake"),
|_| unimplemented!("wake_by_ref"),
|_| (),
);

let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };
let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) };

let atomic_waker = AtomicWaker::new();

let panicking = panic::AssertUnwindSafe(&panicking);

let result = panic::catch_unwind(|| {
let panic::AssertUnwindSafe(panicking) = panicking;
atomic_waker.register_by_ref(panicking);
});

assert!(result.is_err());
assert!(atomic_waker.take_waker().is_none());

atomic_waker.register_by_ref(&nonpanicking);
assert!(atomic_waker.take_waker().is_some());
}
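
The test above relies on a hand-built waker whose `clone` entry panics. The same trick can be tried in isolation with only `std`; the helper names `panicking_waker` and `clone_panics` below are illustrative and not part of the PR:

```rust
use std::panic::{self, AssertUnwindSafe};
use std::ptr;
use std::task::{RawWaker, RawWakerVTable, Waker};

// Vtable entries, in order: clone, wake, wake_by_ref, drop.
// Only `drop` is a no-op; `clone` panics on purpose.
static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new(
    |_| panic!("clone"),
    |_| unimplemented!("wake"),
    |_| unimplemented!("wake_by_ref"),
    |_| (),
);

fn panicking_waker() -> Waker {
    // SAFETY: the null data pointer is never dereferenced by any vtable entry,
    // and `drop` does nothing, so there is no resource to manage.
    unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }
}

/// Returns `true` if cloning `waker` panics.
fn clone_panics(waker: &Waker) -> bool {
    panic::catch_unwind(AssertUnwindSafe(|| waker.clone())).is_err()
}

fn main() {
    let waker = panicking_waker();
    assert!(clone_panics(&waker));
}
```

Registering by reference has to clone the waker to store it, which is why a panicking `clone` is enough to exercise the unwind path in `register_by_ref`.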
55 changes: 55 additions & 0 deletions tokio/src/sync/tests/loom_atomic_waker.rs
@@ -43,3 +43,58 @@ fn basic_notification() {
}));
});
}

#[test]
fn test_panicky_waker() {
use std::panic;
use std::ptr;
use std::task::{RawWaker, RawWakerVTable, Waker};

static PANICKING_VTABLE: RawWakerVTable =
RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ());

let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };

// If you're working with this test (and I sure hope you never have to!),
// uncomment the following section because there will be a lot of panics
// which would otherwise log.
//
// We can't, however, leave it uncommented, because it's global.
// panic::set_hook(Box::new(|_| ()));

const NUM_NOTIFY: usize = 2;

loom::model(move || {
let chan = Arc::new(Chan {
num: AtomicUsize::new(0),
task: AtomicWaker::new(),
});

for _ in 0..NUM_NOTIFY {
let chan = chan.clone();

thread::spawn(move || {
chan.num.fetch_add(1, Relaxed);
chan.task.wake();
});
}

// Note: this panic should have no effect on the overall state of the
// waker and it should proceed as normal.
//
// A thread above might race to flag a wakeup, and a WAKING state will
// be preserved if this expected panic races with that so the below
// procedure should be allowed to continue uninterrupted.
let _ = panic::catch_unwind(|| chan.task.register_by_ref(&panicking));

block_on(poll_fn(move |cx| {
chan.task.register_by_ref(cx.waker());

if NUM_NOTIFY == chan.num.load(Relaxed) {
return Ready(());
}

Pending
}));
});
}