Skip to content

Commit

Permalink
Reduce unsafe code in array queue/bounded channel
Browse files Browse the repository at this point in the history
  • Loading branch information
taiki-e committed Jan 22, 2022
1 parent 544b4e8 commit aca2e31
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 75 deletions.
53 changes: 18 additions & 35 deletions crossbeam-channel/src/flavors/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
//! - <https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub>
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::{self, AtomicUsize, Ordering};
Expand Down Expand Up @@ -72,7 +71,7 @@ pub(crate) struct Channel<T> {
tail: CachePadded<AtomicUsize>,

/// The buffer holding slots.
buffer: *mut Slot<T>,
buffer: Box<[Slot<T>]>,

/// The channel capacity.
cap: usize,
Expand All @@ -88,9 +87,6 @@ pub(crate) struct Channel<T> {

/// Receivers waiting while the channel is empty and not disconnected.
receivers: SyncWaker,

/// Indicates that dropping a `Channel<T>` may drop values of type `T`.
_marker: PhantomData<T>,
}

impl<T> Channel<T> {
Expand All @@ -109,18 +105,15 @@ impl<T> Channel<T> {

// Allocate a buffer of `cap` slots initialized
// with stamps.
let buffer = {
let boxed: Box<[Slot<T>]> = (0..cap)
.map(|i| {
// Set the stamp to `{ lap: 0, mark: 0, index: i }`.
Slot {
stamp: AtomicUsize::new(i),
msg: UnsafeCell::new(MaybeUninit::uninit()),
}
})
.collect();
Box::into_raw(boxed) as *mut Slot<T>
};
let buffer: Box<[Slot<T>]> = (0..cap)
.map(|i| {
// Set the stamp to `{ lap: 0, mark: 0, index: i }`.
Slot {
stamp: AtomicUsize::new(i),
msg: UnsafeCell::new(MaybeUninit::uninit()),
}
})
.collect();

Channel {
buffer,
Expand All @@ -131,7 +124,6 @@ impl<T> Channel<T> {
tail: CachePadded::new(AtomicUsize::new(tail)),
senders: SyncWaker::new(),
receivers: SyncWaker::new(),
_marker: PhantomData,
}
}

Expand Down Expand Up @@ -163,7 +155,8 @@ impl<T> Channel<T> {
let lap = tail & !(self.one_lap - 1);

// Inspect the corresponding slot.
let slot = unsafe { &*self.buffer.add(index) };
debug_assert!(index < self.buffer.len());
let slot = unsafe { self.buffer.get_unchecked(index) };
let stamp = slot.stamp.load(Ordering::Acquire);

// If the tail and the stamp match, we may attempt to push.
Expand Down Expand Up @@ -245,7 +238,8 @@ impl<T> Channel<T> {
let lap = head & !(self.one_lap - 1);

// Inspect the corresponding slot.
let slot = unsafe { &*self.buffer.add(index) };
debug_assert!(index < self.buffer.len());
let slot = unsafe { self.buffer.get_unchecked(index) };
let stamp = slot.stamp.load(Ordering::Acquire);

// If the the stamp is ahead of the head by 1, we may attempt to pop.
Expand Down Expand Up @@ -540,23 +534,12 @@ impl<T> Drop for Channel<T> {
};

unsafe {
let p = {
let slot = &mut *self.buffer.add(index);
let msg = &mut *slot.msg.get();
msg.as_mut_ptr()
};
p.drop_in_place();
debug_assert!(index < self.buffer.len());
let slot = self.buffer.get_unchecked_mut(index);
let msg = &mut *slot.msg.get();
msg.as_mut_ptr().drop_in_place();
}
}

// Finally, deallocate the buffer, but don't run any destructors.
unsafe {
// Create a slice from the buffer to make
// a fat pointer. Then, use Box::from_raw
// to deallocate it.
let ptr = std::slice::from_raw_parts_mut(self.buffer, self.cap) as *mut [Slot<T>];
Box::from_raw(ptr);
}
}
}

Expand Down
87 changes: 85 additions & 2 deletions crossbeam-channel/tests/array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#![feature(vec_into_raw_parts)]
//! Tests for the array channel flavor.
#![cfg(not(miri))] // TODO: many assertions failed due to Miri is slow

use std::any::Any;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
Expand Down Expand Up @@ -254,7 +253,13 @@ fn recv_after_disconnect() {

#[test]
fn len() {
#[cfg(miri)]
const COUNT: usize = 250;
#[cfg(not(miri))]
const COUNT: usize = 25_000;
#[cfg(miri)]
const CAP: usize = 100;
#[cfg(not(miri))]
const CAP: usize = 1000;

let (s, r) = bounded(CAP);
Expand Down Expand Up @@ -347,6 +352,9 @@ fn disconnect_wakes_receiver() {

#[test]
fn spsc() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 100_000;

let (s, r) = bounded(3);
Expand All @@ -369,6 +377,9 @@ fn spsc() {

#[test]
fn mpmc() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 25_000;
const THREADS: usize = 4;

Expand Down Expand Up @@ -401,6 +412,9 @@ fn mpmc() {

#[test]
fn stress_oneshot() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 10_000;

for _ in 0..COUNT {
Expand All @@ -416,6 +430,9 @@ fn stress_oneshot() {

#[test]
fn stress_iter() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 100_000;

let (request_s, request_r) = bounded(1);
Expand Down Expand Up @@ -481,6 +498,7 @@ fn stress_timeout_two_threads() {
.unwrap();
}

#[cfg_attr(miri, ignore)] // Miri is too slow
#[test]
fn drops() {
const RUNS: usize = 100;
Expand Down Expand Up @@ -533,6 +551,9 @@ fn drops() {

#[test]
fn linearizable() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 25_000;
const THREADS: usize = 4;

Expand All @@ -553,6 +574,9 @@ fn linearizable() {

#[test]
fn fairness() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 10_000;

let (s1, r1) = bounded::<()>(COUNT);
Expand All @@ -575,6 +599,9 @@ fn fairness() {

#[test]
fn fairness_duplicates() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 10_000;

let (s, r) = bounded::<()>(COUNT);
Expand Down Expand Up @@ -619,6 +646,9 @@ fn recv_in_send() {

#[test]
fn channel_through_channel() {
#[cfg(miri)]
const COUNT: usize = 100;
#[cfg(not(miri))]
const COUNT: usize = 1000;

type T = Box<dyn Any + Send>;
Expand Down Expand Up @@ -654,3 +684,56 @@ fn channel_through_channel() {
})
.unwrap();
}

#[test]
fn panic_on_drop() {
struct Msg1<'a>(&'a mut bool);
impl Drop for Msg1<'_> {
fn drop(&mut self) {
if *self.0 && !std::thread::panicking() {
panic!("double drop");
} else {
*self.0 = true;
}
}
}

struct Msg2<'a>(&'a mut bool);
impl Drop for Msg2<'_> {
fn drop(&mut self) {
if *self.0 {
panic!("double drop");
} else {
*self.0 = true;
panic!("first drop");
}
}
}

// normal
let (s, r) = bounded(2);
let (mut a, mut b) = (false, false);
s.send(Msg1(&mut a)).unwrap();
s.send(Msg1(&mut b)).unwrap();
drop(s);
drop(r);
assert!(a);
assert!(b);

// panic on drop
let (s, r) = bounded(2);
let (mut a, mut b) = (false, false);
s.send(Msg2(&mut a)).unwrap();
s.send(Msg2(&mut b)).unwrap();
drop(s);
let res = std::panic::catch_unwind(move || {
drop(r);
});
assert_eq!(
*res.unwrap_err().downcast_ref::<&str>().unwrap(),
"first drop"
);
assert!(a);
// Elements after the panicked element will leak.
assert!(!b);
}
Loading

0 comments on commit aca2e31

Please sign in to comment.