Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce code duplication / unnecessary branches #74

Merged
merged 2 commits into from
Aug 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 39 additions & 48 deletions src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ use core::task::{Context, Poll};
/// ```
#[cfg(feature = "std")]
pub fn block_on<T>(future: impl Future<Output = T>) -> T {
use std::cell::RefCell;
use std::task::Waker;
use core::cell::RefCell;
use core::task::Waker;

use parking::Parker;
use waker_fn::waker_fn;
Expand All @@ -77,33 +77,28 @@ pub fn block_on<T>(future: impl Future<Output = T>) -> T {

CACHE.with(|cache| {
// Try grabbing the cached parker and waker.
match cache.try_borrow_mut() {
let tmp_cached;
let tmp_fresh;
let (parker, waker) = match cache.try_borrow_mut() {
Ok(cache) => {
// Use the cached parker and waker.
let (parker, waker) = &*cache;
let cx = &mut Context::from_waker(waker);

// Keep polling until the future is ready.
loop {
match future.as_mut().poll(cx) {
Poll::Ready(output) => return output,
Poll::Pending => parker.park(),
}
}
tmp_cached = cache;
&*tmp_cached
}
Err(_) => {
// Looks like this is a recursive `block_on()` call.
// Create a fresh parker and waker.
let (parker, waker) = parker_and_waker();
let cx = &mut Context::from_waker(&waker);

// Keep polling until the future is ready.
loop {
match future.as_mut().poll(cx) {
Poll::Ready(output) => return output,
Poll::Pending => parker.park(),
}
}
tmp_fresh = parker_and_waker();
&tmp_fresh
}
};

let cx = &mut Context::from_waker(waker);
// Keep polling until the future is ready.
loop {
match future.as_mut().poll(cx) {
Poll::Ready(output) => return output,
Poll::Pending => parker.park(),
}
}
})
Expand Down Expand Up @@ -289,6 +284,18 @@ pin_project! {
}
}

/// Extracts the contents of two options and zips them, handling `(Some(_), None)` cases
fn take_zip_from_parts<T1, T2>(o1: &mut Option<T1>, o2: &mut Option<T2>) -> Poll<(T1, T2)> {
match (o1.take(), o2.take()) {
(Some(t1), Some(t2)) => Poll::Ready((t1, t2)),
(o1x, o2x) => {
*o1 = o1x;
*o2 = o2x;
Poll::Pending
}
}
}

impl<F1, F2> Future for Zip<F1, F2>
where
F1: Future,
Expand All @@ -311,11 +318,7 @@ where
}
}

if this.output1.is_some() && this.output2.is_some() {
Poll::Ready((this.output1.take().unwrap(), this.output2.take().unwrap()))
} else {
Poll::Pending
}
take_zip_from_parts(this.output1, this.output2)
}
}

Expand All @@ -333,7 +336,7 @@ where
/// assert_eq!(future::try_zip(a, b).await, Err(2));
/// # })
/// ```
pub fn try_zip<T1, T2, E, F1, F2>(future1: F1, future2: F2) -> TryZip<F1, F2>
pub fn try_zip<T1, T2, E, F1, F2>(future1: F1, future2: F2) -> TryZip<F1, T1, F2, T2>
where
F1: Future<Output = Result<T1, E>>,
F2: Future<Output = Result<T2, E>>,
Expand All @@ -350,21 +353,17 @@ pin_project! {
/// Future for the [`try_zip()`] function.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct TryZip<F1, F2>
where
F1: Future,
F2: Future,
{
pub struct TryZip<F1, T1, F2, T2> {
#[pin]
future1: F1,
output1: Option<F1::Output>,
output1: Option<T1>,
#[pin]
future2: F2,
output2: Option<F2::Output>,
output2: Option<T2>,
}
}

impl<T1, T2, E, F1, F2> Future for TryZip<F1, F2>
impl<T1, T2, E, F1, F2> Future for TryZip<F1, T1, F2, T2>
where
F1: Future<Output = Result<T1, E>>,
F2: Future<Output = Result<T2, E>>,
Expand All @@ -377,7 +376,7 @@ where
if this.output1.is_none() {
if let Poll::Ready(out) = this.future1.poll(cx) {
match out {
Ok(t) => *this.output1 = Some(Ok(t)),
Ok(t) => *this.output1 = Some(t),
Err(err) => return Poll::Ready(Err(err)),
}
}
Expand All @@ -386,21 +385,13 @@ where
if this.output2.is_none() {
if let Poll::Ready(out) = this.future2.poll(cx) {
match out {
Ok(t) => *this.output2 = Some(Ok(t)),
Ok(t) => *this.output2 = Some(t),
Err(err) => return Poll::Ready(Err(err)),
}
}
}

if this.output1.is_some() && this.output2.is_some() {
let res1 = this.output1.take().unwrap();
let res2 = this.output2.take().unwrap();
let t1 = res1.map_err(|_| unreachable!()).unwrap();
let t2 = res2.map_err(|_| unreachable!()).unwrap();
Poll::Ready(Ok((t1, t2)))
} else {
Poll::Pending
}
take_zip_from_parts(this.output1, this.output2).map(Ok)
}
}

Expand Down