From bafc3ad22ebde70f8f03d7ca122253168423024d Mon Sep 17 00:00:00 2001
From: Taiki Endo
Date: Sat, 22 Jan 2022 13:50:04 +0900
Subject: [PATCH] Reduce unsafe code in array queue/bounded channel

---
 crossbeam-channel/src/flavors/array.rs | 53 ++++++----------
 crossbeam-channel/tests/array.rs       | 86 +++++++++++++++++++++++++-
 crossbeam-queue/src/array_queue.rs     | 59 +++++++-----------
 3 files changed, 123 insertions(+), 75 deletions(-)

diff --git a/crossbeam-channel/src/flavors/array.rs b/crossbeam-channel/src/flavors/array.rs
index c8d0bcde3..dc4a7bf38 100644
--- a/crossbeam-channel/src/flavors/array.rs
+++ b/crossbeam-channel/src/flavors/array.rs
@@ -9,7 +9,6 @@
 //!   -
 
 use std::cell::UnsafeCell;
-use std::marker::PhantomData;
 use std::mem::MaybeUninit;
 use std::ptr;
 use std::sync::atomic::{self, AtomicUsize, Ordering};
@@ -72,7 +71,7 @@ pub(crate) struct Channel<T> {
     tail: CachePadded<AtomicUsize>,
 
     /// The buffer holding slots.
-    buffer: *mut Slot<T>,
+    buffer: Box<[Slot<T>]>,
 
     /// The channel capacity.
     cap: usize,
@@ -88,9 +87,6 @@ pub(crate) struct Channel<T> {
 
     /// Receivers waiting while the channel is empty and not disconnected.
     receivers: SyncWaker,
-
-    /// Indicates that dropping a `Channel<T>` may drop values of type `T`.
-    _marker: PhantomData<T>,
 }
 
 impl<T> Channel<T> {
@@ -109,18 +105,15 @@ impl<T> Channel<T> {
 
         // Allocate a buffer of `cap` slots initialized
        // with stamps.
-        let buffer = {
-            let boxed: Box<[Slot<T>]> = (0..cap)
-                .map(|i| {
-                    // Set the stamp to `{ lap: 0, mark: 0, index: i }`.
-                    Slot {
-                        stamp: AtomicUsize::new(i),
-                        msg: UnsafeCell::new(MaybeUninit::uninit()),
-                    }
-                })
-                .collect();
-            Box::into_raw(boxed) as *mut Slot<T>
-        };
+        let buffer: Box<[Slot<T>]> = (0..cap)
+            .map(|i| {
+                // Set the stamp to `{ lap: 0, mark: 0, index: i }`.
+                Slot {
+                    stamp: AtomicUsize::new(i),
+                    msg: UnsafeCell::new(MaybeUninit::uninit()),
+                }
+            })
+            .collect();
 
         Channel {
             buffer,
@@ -131,7 +124,6 @@ impl<T> Channel<T> {
             tail: CachePadded::new(AtomicUsize::new(tail)),
             senders: SyncWaker::new(),
             receivers: SyncWaker::new(),
-            _marker: PhantomData,
         }
     }
 
@@ -163,7 +155,8 @@ impl<T> Channel<T> {
             let lap = tail & !(self.one_lap - 1);
 
             // Inspect the corresponding slot.
-            let slot = unsafe { &*self.buffer.add(index) };
+            debug_assert!(index < self.buffer.len());
+            let slot = unsafe { self.buffer.get_unchecked(index) };
             let stamp = slot.stamp.load(Ordering::Acquire);
 
             // If the tail and the stamp match, we may attempt to push.
@@ -245,7 +238,8 @@ impl<T> Channel<T> {
             let lap = head & !(self.one_lap - 1);
 
             // Inspect the corresponding slot.
-            let slot = unsafe { &*self.buffer.add(index) };
+            debug_assert!(index < self.buffer.len());
+            let slot = unsafe { self.buffer.get_unchecked(index) };
             let stamp = slot.stamp.load(Ordering::Acquire);
 
             // If the stamp is ahead of the head by 1, we may attempt to pop.
@@ -540,23 +534,12 @@ impl<T> Drop for Channel<T> {
             };
 
             unsafe {
-                let p = {
-                    let slot = &mut *self.buffer.add(index);
-                    let msg = &mut *slot.msg.get();
-                    msg.as_mut_ptr()
-                };
-                p.drop_in_place();
+                debug_assert!(index < self.buffer.len());
+                let slot = self.buffer.get_unchecked_mut(index);
+                let msg = &mut *slot.msg.get();
+                msg.as_mut_ptr().drop_in_place();
             }
         }
-
-        // Finally, deallocate the buffer, but don't run any destructors.
-        unsafe {
-            // Create a slice from the buffer to make
-            // a fat pointer. Then, use Box::from_raw
-            // to deallocate it.
-            let ptr = std::slice::from_raw_parts_mut(self.buffer, self.cap) as *mut [Slot<T>];
-            Box::from_raw(ptr);
-        }
     }
 }
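
Note on the flavors/array.rs hunks above: storing the buffer as a `Box<[Slot<T>]>` makes the slice own its allocation, so the `Box::into_raw`/`Box::from_raw` round-trip and the `PhantomData<T>` drop-marker field both become unnecessary, and indexing can pair a `debug_assert!` with `get_unchecked` instead of raw pointer arithmetic. A minimal standalone sketch of that ownership pattern (toy `Buffer`/`init` names, not the crossbeam code); only the message destructors still run by hand, because each slot holds a `MaybeUninit<T>`:

    use std::cell::UnsafeCell;
    use std::mem::MaybeUninit;
    use std::sync::atomic::AtomicUsize;

    struct Slot<T> {
        stamp: AtomicUsize,              // `{ lap, mark, index }` stamp
        msg: UnsafeCell<MaybeUninit<T>>, // possibly-uninitialized message
    }

    struct Buffer<T> {
        slots: Box<[Slot<T>]>, // owns the allocation; no `PhantomData<T>` needed
        init: usize,           // toy invariant: slots `0..init` are initialized
    }

    impl<T> Buffer<T> {
        fn new(cap: usize) -> Self {
            let slots = (0..cap)
                .map(|i| Slot {
                    stamp: AtomicUsize::new(i),
                    msg: UnsafeCell::new(MaybeUninit::uninit()),
                })
                .collect();
            Buffer { slots, init: 0 }
        }
    }

    impl<T> Drop for Buffer<T> {
        fn drop(&mut self) {
            // Drop only the initialized messages; the `Box<[Slot<T>]>` then
            // frees the allocation itself without running `Slot` destructors.
            for slot in &mut self.slots[..self.init] {
                unsafe { (*slot.msg.get()).as_mut_ptr().drop_in_place() }
            }
        }
    }

    fn main() {
        let _buf: Buffer<String> = Buffer::new(4); // deallocates cleanly, drops nothing
    }

Dropping `Buffer` frees the slots through the `Box` even when no message was ever written, which is the case the removed manual-deallocation block used to handle.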
diff --git a/crossbeam-channel/tests/array.rs b/crossbeam-channel/tests/array.rs
index bb2cebe88..7c4fc4dff 100644
--- a/crossbeam-channel/tests/array.rs
+++ b/crossbeam-channel/tests/array.rs
@@ -1,7 +1,5 @@
 //! Tests for the array channel flavor.
 
-#![cfg(not(miri))] // TODO: many assertions failed due to Miri is slow
-
 use std::any::Any;
 use std::sync::atomic::AtomicUsize;
 use std::sync::atomic::Ordering;
@@ -254,7 +252,13 @@ fn recv_after_disconnect() {
 
 #[test]
 fn len() {
+    #[cfg(miri)]
+    const COUNT: usize = 250;
+    #[cfg(not(miri))]
     const COUNT: usize = 25_000;
+    #[cfg(miri)]
+    const CAP: usize = 100;
+    #[cfg(not(miri))]
     const CAP: usize = 1000;
 
     let (s, r) = bounded(CAP);
@@ -347,6 +351,9 @@ fn disconnect_wakes_receiver() {
 
 #[test]
 fn spsc() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 100_000;
 
     let (s, r) = bounded(3);
@@ -369,6 +376,9 @@ fn spsc() {
 
 #[test]
 fn mpmc() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 25_000;
     const THREADS: usize = 4;
 
@@ -401,6 +411,9 @@ fn mpmc() {
 
 #[test]
 fn stress_oneshot() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 10_000;
 
     for _ in 0..COUNT {
@@ -416,6 +429,9 @@ fn stress_oneshot() {
 
 #[test]
 fn stress_iter() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 100_000;
 
     let (request_s, request_r) = bounded(1);
@@ -481,6 +497,7 @@ fn stress_timeout_two_threads() {
         .unwrap();
 }
 
+#[cfg_attr(miri, ignore)] // Miri is too slow
 #[test]
 fn drops() {
     const RUNS: usize = 100;
@@ -533,6 +550,9 @@ fn drops() {
 
 #[test]
 fn linearizable() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 25_000;
     const THREADS: usize = 4;
 
@@ -553,6 +573,9 @@ fn linearizable() {
 
 #[test]
 fn fairness() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 10_000;
 
     let (s1, r1) = bounded::<()>(COUNT);
@@ -575,6 +598,9 @@ fn fairness() {
 
 #[test]
 fn fairness_duplicates() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 10_000;
 
     let (s, r) = bounded::<()>(COUNT);
@@ -619,6 +645,9 @@ fn recv_in_send() {
 
 #[test]
 fn channel_through_channel() {
+    #[cfg(miri)]
+    const COUNT: usize = 100;
+    #[cfg(not(miri))]
     const COUNT: usize = 1000;
 
     type T = Box<dyn Any + Send>;
@@ -654,3 +683,56 @@ fn channel_through_channel() {
         })
         .unwrap();
 }
+
+#[test]
+fn panic_on_drop() {
+    struct Msg1<'a>(&'a mut bool);
+    impl Drop for Msg1<'_> {
+        fn drop(&mut self) {
+            if *self.0 && !std::thread::panicking() {
+                panic!("double drop");
+            } else {
+                *self.0 = true;
+            }
+        }
+    }
+
+    struct Msg2<'a>(&'a mut bool);
+    impl Drop for Msg2<'_> {
+        fn drop(&mut self) {
+            if *self.0 {
+                panic!("double drop");
+            } else {
+                *self.0 = true;
+                panic!("first drop");
+            }
+        }
+    }
+
+    // normal
+    let (s, r) = bounded(2);
+    let (mut a, mut b) = (false, false);
+    s.send(Msg1(&mut a)).unwrap();
+    s.send(Msg1(&mut b)).unwrap();
+    drop(s);
+    drop(r);
+    assert!(a);
+    assert!(b);
+
+    // panic on drop
+    let (s, r) = bounded(2);
+    let (mut a, mut b) = (false, false);
+    s.send(Msg2(&mut a)).unwrap();
+    s.send(Msg2(&mut b)).unwrap();
+    drop(s);
+    let res = std::panic::catch_unwind(move || {
+        drop(r);
+    });
+    assert_eq!(
+        *res.unwrap_err().downcast_ref::<&str>().unwrap(),
+        "first drop"
+    );
+    assert!(a);
+    // Elements after the panicked element will leak.
+    assert!(!b);
+}
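
The new `panic_on_drop` test above pins down the current leak semantics: when a message's destructor panics while the channel is dropping its remaining messages, the unguarded loop in `Drop for Channel` is abandoned mid-way, so every message after the panicking one is leaked rather than dropped. A toy reproduction of that effect outside the channel (hypothetical `Noisy` type, not part of the patch):

    use std::mem::ManuallyDrop;
    use std::panic::{self, AssertUnwindSafe};

    struct Noisy(&'static str, bool); // (name, whether drop panics)

    impl Drop for Noisy {
        fn drop(&mut self) {
            println!("dropping {}", self.0);
            if self.1 {
                panic!("first drop");
            }
        }
    }

    fn main() {
        let mut buf = [
            ManuallyDrop::new(Noisy("a", true)),
            ManuallyDrop::new(Noisy("b", false)),
        ];
        let res = panic::catch_unwind(AssertUnwindSafe(|| {
            // An unguarded drop loop, like the one in `Drop for Channel`:
            // "a" panics, the loop is abandoned, and "b" is never dropped.
            for slot in buf.iter_mut() {
                unsafe { ManuallyDrop::drop(slot) };
            }
        }));
        assert!(res.is_err()); // "dropping b" was never printed: "b" leaked
    }

The standard library's slice drop glue does guard against this and keeps dropping later elements while unwinding; the test documents that the channel's manual loop does not.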
diff --git a/crossbeam-queue/src/array_queue.rs b/crossbeam-queue/src/array_queue.rs
index 048767f28..5f3061b70 100644
--- a/crossbeam-queue/src/array_queue.rs
+++ b/crossbeam-queue/src/array_queue.rs
@@ -6,9 +6,7 @@
 use alloc::boxed::Box;
 use core::cell::UnsafeCell;
 use core::fmt;
-use core::marker::PhantomData;
 use core::mem::MaybeUninit;
-use core::ptr;
 use core::sync::atomic::{self, AtomicUsize, Ordering};
 
 use crossbeam_utils::{Backoff, CachePadded};
@@ -64,16 +62,13 @@ pub struct ArrayQueue<T> {
     tail: CachePadded<AtomicUsize>,
 
     /// The buffer holding slots.
-    buffer: *mut Slot<T>,
+    buffer: Box<[Slot<T>]>,
 
     /// The queue capacity.
     cap: usize,
 
     /// A stamp with the value of `{ lap: 1, index: 0 }`.
     one_lap: usize,
-
-    /// Indicates that dropping an `ArrayQueue<T>` may drop elements of type `T`.
-    _marker: PhantomData<T>,
 }
 
 unsafe impl<T: Send> Sync for ArrayQueue<T> {}
@@ -103,18 +98,15 @@ impl<T> ArrayQueue<T> {
 
         // Allocate a buffer of `cap` slots initialized
         // with stamps.
-        let buffer = {
-            let boxed: Box<[Slot<T>]> = (0..cap)
-                .map(|i| {
-                    // Set the stamp to `{ lap: 0, index: i }`.
-                    Slot {
-                        stamp: AtomicUsize::new(i),
-                        value: UnsafeCell::new(MaybeUninit::uninit()),
-                    }
-                })
-                .collect();
-            Box::into_raw(boxed) as *mut Slot<T>
-        };
+        let buffer: Box<[Slot<T>]> = (0..cap)
+            .map(|i| {
+                // Set the stamp to `{ lap: 0, index: i }`.
+                Slot {
+                    stamp: AtomicUsize::new(i),
+                    value: UnsafeCell::new(MaybeUninit::uninit()),
+                }
+            })
+            .collect();
 
         // One lap is the smallest power of two greater than `cap`.
         let one_lap = (cap + 1).next_power_of_two();
@@ -125,7 +117,6 @@ impl<T> ArrayQueue<T> {
             one_lap,
             head: CachePadded::new(AtomicUsize::new(head)),
             tail: CachePadded::new(AtomicUsize::new(tail)),
-            _marker: PhantomData,
         }
     }
 
@@ -153,7 +144,8 @@ impl<T> ArrayQueue<T> {
            let lap = tail & !(self.one_lap - 1);
 
             // Inspect the corresponding slot.
-            let slot = unsafe { &*self.buffer.add(index) };
+            debug_assert!(index < self.buffer.len());
+            let slot = unsafe { self.buffer.get_unchecked(index) };
             let stamp = slot.stamp.load(Ordering::Acquire);
 
             // If the tail and the stamp match, we may attempt to push.
@@ -233,7 +225,8 @@ impl<T> ArrayQueue<T> {
             let lap = head & !(self.one_lap - 1);
 
             // Inspect the corresponding slot.
-            let slot = unsafe { &*self.buffer.add(index) };
+            debug_assert!(index < self.buffer.len());
+            let slot = unsafe { self.buffer.get_unchecked(index) };
             let stamp = slot.stamp.load(Ordering::Acquire);
 
             // If the stamp is ahead of the head by 1, we may attempt to pop.
@@ -406,23 +399,12 @@ impl<T> Drop for ArrayQueue<T> {
             };
 
             unsafe {
-                let p = {
-                    let slot = &mut *self.buffer.add(index);
-                    let value = &mut *slot.value.get();
-                    value.as_mut_ptr()
-                };
-                p.drop_in_place();
+                debug_assert!(index < self.buffer.len());
+                let slot = self.buffer.get_unchecked_mut(index);
+                let value = &mut *slot.value.get();
+                value.as_mut_ptr().drop_in_place();
             }
         }
-
-        // Finally, deallocate the buffer, but don't run any destructors.
-        unsafe {
-            // Create a slice from the buffer to make
-            // a fat pointer. Then, use Box::from_raw
-            // to deallocate it.
-            let ptr = core::slice::from_raw_parts_mut(self.buffer, self.cap) as *mut [Slot<T>];
-            Box::from_raw(ptr);
-        }
     }
 }
 
@@ -461,8 +443,9 @@ impl<T> Iterator for IntoIter<T> {
             // initialized because it is the value pointed at by `value.head`
             // and this is a non-empty queue.
             let val = unsafe {
-                let slot = &mut *value.buffer.add(index);
-                ptr::read(slot.value.get()).assume_init()
+                debug_assert!(index < value.buffer.len());
+                let slot = value.buffer.get_unchecked_mut(index);
+                slot.value.get().read().assume_init()
             };
             let new = if index + 1 < value.cap {
                 // Same lap, incremented index.
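
For reference, the stamp arithmetic that the new `debug_assert!`s lean on, reconstructed from the context lines above (`one_lap` is the smallest power of two greater than `cap`, and a stamp packs `{ lap, index }` into a single `usize`; the channel flavor additionally reserves a mark bit for disconnection):

    fn main() {
        let cap = 5_usize;
        let one_lap = (cap + 1).next_power_of_two(); // 8

        let stamp = 2 * one_lap + 3; // encodes { lap: 2, index: 3 }
        let index = stamp & (one_lap - 1); // low bits
        let lap = stamp & !(one_lap - 1); // high bits

        assert_eq!(index, 3);
        assert_eq!(lap, 2 * one_lap);

        // `index` is always below `one_lap`, but only indices below `cap` are
        // ever stored, which is exactly what the new
        // `debug_assert!(index < self.buffer.len())` checks before each
        // `get_unchecked` in the hunks above.
        assert!(index < cap);
    }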