Skip to content

Commit

Permalink
Shared: fix shared futures missing wake up
Browse files Browse the repository at this point in the history
  • Loading branch information
Zekun Li authored and Zekun Li committed Nov 27, 2024
1 parent 7211cb7 commit 3b7eacb
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 48 deletions.
111 changes: 64 additions & 47 deletions futures-util/src/future/future/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::pin::Pin;
use std::ptr;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Acquire, SeqCst};
use std::sync::{Arc, Mutex, Weak};
use std::sync::{Arc, Mutex, MutexGuard, Weak};

/// Future for the [`shared`](super::FutureExt::shared) method.
#[must_use = "futures do nothing unless you `.await` or poll them"]
Expand Down Expand Up @@ -81,6 +81,7 @@ const IDLE: usize = 0;
const POLLING: usize = 1;
const COMPLETE: usize = 2;
const POISONED: usize = 3;
const AWAKEN_DURING_POLLING: usize = 4;

const NULL_WAKER_KEY: usize = usize::MAX;

Expand Down Expand Up @@ -197,36 +198,47 @@ where
}
}

impl<Fut> Inner<Fut>
where
Fut: Future,
Fut::Output: Clone,
{
/// Registers the current task to receive a wakeup when we are awoken.
fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
let mut wakers_guard = self.notifier.wakers.lock().unwrap();

let wakers_mut = wakers_guard.as_mut();

let wakers = match wakers_mut {
Some(wakers) => wakers,
None => return,
};

let new_waker = cx.waker();
/// Registers the current task to receive a wakeup when we are awoken.
fn record_waker(
wakers_guard: &mut MutexGuard<'_, Option<Slab<Option<Waker>>>>,
waker_key: &mut usize,
cx: &mut Context<'_>,
) {
let wakers = match wakers_guard.as_mut() {
Some(wakers) => wakers,
None => return,
};

let new_waker = cx.waker();

if *waker_key == NULL_WAKER_KEY {
*waker_key = wakers.insert(Some(new_waker.clone()));
} else {
match wakers[*waker_key] {
Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
// Could use clone_from here, but Waker doesn't specialize it.
ref mut slot => *slot = Some(new_waker.clone()),
}
}
debug_assert!(*waker_key != NULL_WAKER_KEY);
}

if *waker_key == NULL_WAKER_KEY {
*waker_key = wakers.insert(Some(new_waker.clone()));
} else {
match wakers[*waker_key] {
Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
// Could use clone_from here, but Waker doesn't specialize it.
ref mut slot => *slot = Some(new_waker.clone()),
/// Wakes all tasks that are registered to be woken.
fn wake_all(waker_guard: &mut MutexGuard<'_, Option<Slab<Option<Waker>>>>) {
if let Some(wakers) = waker_guard.as_mut() {
for (_key, opt_waker) in wakers {
if let Some(waker) = opt_waker.take() {
waker.wake();
}
}
debug_assert!(*waker_key != NULL_WAKER_KEY);
}
}

impl<Fut> Inner<Fut>
where
Fut: Future,
Fut::Output: Clone,
{
/// Safety: callers must first ensure that `inner.state`
/// is `COMPLETE`
unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
Expand Down Expand Up @@ -268,18 +280,22 @@ where
return unsafe { Poll::Ready(inner.take_or_clone_output()) };
}

inner.record_waker(&mut this.waker_key, cx);
// Guard the state transition with mutex too
let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
record_waker(&mut wakers_guard, &mut this.waker_key, cx);

match inner
let prev = inner
.notifier
.state
.compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
.unwrap_or_else(|x| x)
{
.unwrap_or_else(|x| x);
drop(wakers_guard);

match prev {
IDLE => {
// Lock acquired, fall through
}
POLLING => {
POLLING | AWAKEN_DURING_POLLING => {
// Another task is currently polling, at this point we just want
// to ensure that the waker for this task is registered
this.inner = Some(inner);
Expand Down Expand Up @@ -324,15 +340,21 @@ where

match poll_result {
Poll::Pending => {
if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
{
// Success
drop(reset);
this.inner = Some(inner);
return Poll::Pending;
} else {
unreachable!()
match inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst) {
Ok(POLLING) => {} // success
Err(AWAKEN_DURING_POLLING) => {
// waker has been called inside future.poll, need to wake any new wakers registered
let mut wakers = inner.notifier.wakers.lock().unwrap();
wake_all(&mut wakers);
let prev = inner.notifier.state.swap(IDLE, SeqCst);
assert_eq!(prev, AWAKEN_DURING_POLLING);
drop(wakers);
}
_ => unreachable!(),
}
drop(reset);
this.inner = Some(inner);
return Poll::Pending;
}
Poll::Ready(output) => output,
}
Expand Down Expand Up @@ -387,14 +409,9 @@ where

impl ArcWake for Notifier {
fn wake_by_ref(arc_self: &Arc<Self>) {
let wakers = &mut *arc_self.wakers.lock().unwrap();
if let Some(wakers) = wakers.as_mut() {
for (_key, opt_waker) in wakers {
if let Some(waker) = opt_waker.take() {
waker.wake();
}
}
}
let mut wakers = arc_self.wakers.lock().unwrap();
let _ = arc_self.state.compare_exchange(POLLING, AWAKEN_DURING_POLLING, SeqCst, SeqCst);
wake_all(&mut wakers);
}
}

Expand Down
53 changes: 52 additions & 1 deletion futures/tests/future_shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ use futures::executor::{block_on, LocalPool};
use futures::future::{self, FutureExt, LocalFutureObj, TryFutureExt};
use futures::task::LocalSpawn;
use std::cell::{Cell, RefCell};
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Poll;
use std::task::{Context, Poll};
use std::thread;

struct CountClone(Rc<Cell<i32>>);
Expand Down Expand Up @@ -271,3 +273,52 @@ fn poll_while_panic() {
let _s = S {};
panic!("test_marker");
}

#[test]
fn shared_futures_woken_during_polling() {
async fn yield_now() {
/// Yield implementation
struct YieldNow {
yielded: bool,
}

impl Future for YieldNow {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
if self.yielded {
return Poll::Ready(());
}

self.yielded = true;
cx.waker().wake_by_ref();
Poll::Pending
}
}

YieldNow { yielded: false }.await
}
fn test() {
let f1 = yield_now().shared();
let f2 = f1.clone();
let x1 = thread::spawn(move || {
block_on(async move {
f1.now_or_never();
})
});
let x2 = thread::spawn(move || {
block_on(async move {
f2.await;
})
});
x1.join().ok();
x2.join().ok();
}

for _ in 0..10 {
print!(".");
for _ in 0..10000 {
test();
}
}
}

0 comments on commit 3b7eacb

Please sign in to comment.