Ensure TLS destructors run before thread joins in SGX #84409

Merged 3 commits on May 7, 2021
6 changes: 4 additions & 2 deletions library/std/src/sys/sgx/abi/mod.rs
@@ -62,10 +62,12 @@ unsafe extern "C" fn tcs_init(secondary: bool) {
 extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> EntryReturn {
     // FIXME: how to support TLS in library mode?
     let tls = Box::new(tls::Tls::new());
-    let _tls_guard = unsafe { tls.activate() };
+    let tls_guard = unsafe { tls.activate() };

     if secondary {
-        super::thread::Thread::entry();
+        let join_notifier = super::thread::Thread::entry();
+        drop(tls_guard);
+        drop(join_notifier);

         EntryReturn(0, 0)
     } else {
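
The ordering enforced by this hunk can be illustrated outside SGX. The sketch below is a hypothetical model, not the enclave code: `TlsGuard`, `JoinNotifier`, `thread_entry`, and `run_entry` are stand-in names introduced here, and "running TLS destructors" is modeled as recording an event when the guard is dropped. It shows that the explicit `drop` calls make the guard's destructor run before the notifier's. Without them, Rust drops locals in reverse declaration order, so a notifier declared after the guard would fire first.

```rust
use std::sync::{Arc, Mutex};

type Log = Arc<Mutex<Vec<&'static str>>>;

/// Hypothetical stand-in for the TLS guard: dropping it models running the
/// thread's TLS destructors.
struct TlsGuard(Log);

impl Drop for TlsGuard {
    fn drop(&mut self) {
        self.0.lock().unwrap().push("tls destructors ran");
    }
}

/// Hypothetical stand-in for the join notifier: dropping it models waking
/// the thread that is blocked in `join`.
struct JoinNotifier(Log);

impl Drop for JoinNotifier {
    fn drop(&mut self) {
        self.0.lock().unwrap().push("joiner notified");
    }
}

/// Models the secondary-thread path: the guard is dropped explicitly before
/// the notifier, regardless of declaration order.
fn thread_entry(log: Log) {
    let tls_guard = TlsGuard(log.clone());
    let join_notifier = JoinNotifier(log);
    drop(tls_guard);
    drop(join_notifier);
}

/// Runs the model once and returns the recorded events in order.
fn run_entry() -> Vec<&'static str> {
    let log: Log = Arc::new(Mutex::new(Vec::new()));
    thread_entry(log.clone());
    let events = log.lock().unwrap().clone();
    events
}

fn main() {
    assert_eq!(run_entry(), ["tls destructors ran", "joiner notified"]);
}
```

Removing the two `drop` calls in `thread_entry` reverses the recorded order in this model, which corresponds to the join-before-destructors problem the PR title describes.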
69 changes: 61 additions & 8 deletions library/std/src/sys/sgx/thread.rs
@@ -9,26 +9,37 @@ pub struct Thread(task_queue::JoinHandle);

 pub const DEFAULT_MIN_STACK_SIZE: usize = 4096;

+pub use self::task_queue::JoinNotifier;
+
 mod task_queue {
-    use crate::sync::mpsc;
+    use super::wait_notify;
     use crate::sync::{Mutex, MutexGuard, Once};

-    pub type JoinHandle = mpsc::Receiver<()>;
+    pub type JoinHandle = wait_notify::Waiter;
+
+    pub struct JoinNotifier(Option<wait_notify::Notifier>);
+
+    impl Drop for JoinNotifier {
+        fn drop(&mut self) {
+            self.0.take().unwrap().notify();
+        }
+    }

     pub(super) struct Task {
         p: Box<dyn FnOnce()>,
-        done: mpsc::Sender<()>,
+        done: JoinNotifier,
     }

     impl Task {
         pub(super) fn new(p: Box<dyn FnOnce()>) -> (Task, JoinHandle) {
-            let (done, recv) = mpsc::channel();
+            let (done, recv) = wait_notify::new();
+            let done = JoinNotifier(Some(done));
             (Task { p, done }, recv)
         }

-        pub(super) fn run(self) {
+        pub(super) fn run(self) -> JoinNotifier {
             (self.p)();
-            let _ = self.done.send(());
+            self.done
         }
     }

@@ -47,6 +58,48 @@ mod task_queue {
}
}

/// This module provides a synchronization primitive that does not use thread
/// local variables. This is needed for signaling that a thread has finished
/// execution. The signal is sent once all TLS destructors have finished at
/// which point no new thread locals should be created.
pub mod wait_notify {
use super::super::waitqueue::{SpinMutex, WaitQueue, WaitVariable};
use crate::sync::Arc;

pub struct Notifier(Arc<SpinMutex<WaitVariable<bool>>>);

impl Notifier {
/// Notify the waiter. The waiter is either notified right away (if
/// currently blocked in `Waiter::wait()`) or later when it calls the
/// `Waiter::wait()` method.
pub fn notify(self) {
let mut guard = self.0.lock();
*guard.lock_var_mut() = true;
let _ = WaitQueue::notify_one(guard);
}
}

pub struct Waiter(Arc<SpinMutex<WaitVariable<bool>>>);

impl Waiter {
/// Wait for a notification. If `Notifier::notify()` has already been
/// called, this will return immediately, otherwise the current thread
/// is blocked until notified.
pub fn wait(self) {
let guard = self.0.lock();
if *guard.lock_var() {
return;
}
WaitQueue::wait(guard, || {});
}
}

pub fn new() -> (Notifier, Waiter) {
let inner = Arc::new(SpinMutex::new(WaitVariable::new(false)));
(Notifier(inner.clone()), Waiter(inner))
}
}

impl Thread {
// unsafe: see thread::Builder::spawn_unchecked for safety requirements
pub unsafe fn new(_stack: usize, p: Box<dyn FnOnce()>) -> io::Result<Thread> {
@@ -57,7 +110,7 @@ impl Thread {
         Ok(Thread(handle))
     }

-    pub(super) fn entry() {
+    pub(super) fn entry() -> JoinNotifier {
         let mut pending_tasks = task_queue::lock();
         let task = rtunwrap!(Some, pending_tasks.pop());
         drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary
@@ -78,7 +131,7 @@ impl Thread {
     }

     pub fn join(self) {
-        let _ = self.0.recv();
+        self.0.wait();
     }
 }
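
The one-shot notify/wait contract documented on `wait_notify` (a notification sent before `wait` is not lost, and one sent during `wait` wakes the waiter) can be sketched portably. The version below is an illustration only: it uses `std::sync::Mutex` and `Condvar` rather than the SGX `SpinMutex`/`WaitQueue`, so it makes no claim about avoiding thread locals and is not a substitute for the module in this diff. `Shared`, `new_pair`, `notify_then_wait`, and `wait_then_notify` are names assumed for this example.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// State shared by both halves: a "done" flag plus a condition variable.
struct Shared {
    done: Mutex<bool>,
    cv: Condvar,
}

pub struct Notifier(Arc<Shared>);
pub struct Waiter(Arc<Shared>);

/// Creates a connected notifier/waiter pair (hypothetical helper name).
pub fn new_pair() -> (Notifier, Waiter) {
    let shared = Arc::new(Shared { done: Mutex::new(false), cv: Condvar::new() });
    (Notifier(shared.clone()), Waiter(shared))
}

impl Notifier {
    /// Sets the flag, then wakes the waiter if it is currently blocked.
    pub fn notify(self) {
        *self.0.done.lock().unwrap() = true;
        self.0.cv.notify_one();
    }
}

impl Waiter {
    /// Returns immediately if already notified; otherwise blocks until the
    /// flag is set. The loop guards against spurious condvar wakeups.
    pub fn wait(self) {
        let mut done = self.0.done.lock().unwrap();
        while !*done {
            done = self.0.cv.wait(done).unwrap();
        }
    }
}

/// Notification arrives before the wait: `wait` must not block.
fn notify_then_wait() -> bool {
    let (notifier, waiter) = new_pair();
    notifier.notify();
    waiter.wait();
    true
}

/// Notification arrives from another thread, possibly while waiting.
fn wait_then_notify() -> bool {
    let (notifier, waiter) = new_pair();
    let handle = thread::spawn(move || notifier.notify());
    waiter.wait();
    handle.join().is_ok()
}

fn main() {
    assert!(notify_then_wait());
    assert!(wait_then_notify());
}
```

Both halves consume `self`, mirroring the signatures in the diff, so each pair can signal at most once.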

108 changes: 108 additions & 0 deletions library/std/src/thread/local/tests.rs
@@ -1,4 +1,5 @@
use crate::cell::{Cell, UnsafeCell};
use crate::sync::atomic::{AtomicU8, Ordering};
use crate::sync::mpsc::{channel, Sender};
use crate::thread::{self, LocalKey};
use crate::thread_local;
@@ -207,3 +208,110 @@ fn dtors_in_dtors_in_dtors_const_init() {
});
rx.recv().unwrap();
}

// This test tests that TLS destructors have run before the thread joins. The
// test has no false positives (meaning: if the test fails, there's actually
// an ordering problem). It may have false negatives, where the test passes but
// join is not guaranteed to be after the TLS destructors. However, false
// negatives should be exceedingly rare due to judicious use of
// thread::yield_now and running the test several times.
#[test]
fn join_orders_after_tls_destructors() {
// We emulate a synchronous MPSC rendezvous channel using only atomics and
// thread::yield_now. We can't use std::mpsc as the implementation itself
// may rely on thread locals.
//
// The basic state machine for an SPSC rendezvous channel is:
// FRESH -> THREAD1_WAITING -> MAIN_THREAD_RENDEZVOUS
// where the first transition is done by the “receiving” thread and the 2nd
// transition is done by the “sending” thread.
//
// We add an additional state `THREAD2_LAUNCHED` between `FRESH` and
// `THREAD1_WAITING` to block until all threads are actually running.
//
// A thread that joins on the “receiving” thread completion should never
// observe the channel in the `THREAD1_WAITING` state. If this does occur,
// we switch to the “poison” state `THREAD2_JOINED` and panic all around.
// (This is equivalent to “sending” from an alternate producer thread.)
const FRESH: u8 = 0;
const THREAD2_LAUNCHED: u8 = 1;
const THREAD1_WAITING: u8 = 2;
const MAIN_THREAD_RENDEZVOUS: u8 = 3;
const THREAD2_JOINED: u8 = 4;
static SYNC_STATE: AtomicU8 = AtomicU8::new(FRESH);

for _ in 0..10 {
SYNC_STATE.store(FRESH, Ordering::SeqCst);

let jh = thread::Builder::new()
.name("thread1".into())
.spawn(move || {
struct TlDrop;

impl Drop for TlDrop {
fn drop(&mut self) {
let mut sync_state = SYNC_STATE.swap(THREAD1_WAITING, Ordering::SeqCst);
loop {
match sync_state {
THREAD2_LAUNCHED | THREAD1_WAITING => thread::yield_now(),
MAIN_THREAD_RENDEZVOUS => break,
THREAD2_JOINED => panic!(
"Thread 1 still running after thread 2 joined on thread 1"
),
v => unreachable!("sync state: {}", v),
}
sync_state = SYNC_STATE.load(Ordering::SeqCst);
}
}
}

thread_local! {
static TL_DROP: TlDrop = TlDrop;
}

TL_DROP.with(|_| {});

loop {
match SYNC_STATE.load(Ordering::SeqCst) {
FRESH => thread::yield_now(),
THREAD2_LAUNCHED => break,
v => unreachable!("sync state: {}", v),
}
}
})
.unwrap();

let jh2 = thread::Builder::new()
.name("thread2".into())
.spawn(move || {
assert_eq!(SYNC_STATE.swap(THREAD2_LAUNCHED, Ordering::SeqCst), FRESH);
jh.join().unwrap();
match SYNC_STATE.swap(THREAD2_JOINED, Ordering::SeqCst) {
MAIN_THREAD_RENDEZVOUS => return,
THREAD2_LAUNCHED | THREAD1_WAITING => {
panic!("Thread 2 running after thread 1 join before main thread rendezvous")
}
v => unreachable!("sync state: {:?}", v),
}
})
.unwrap();

loop {
match SYNC_STATE.compare_exchange_weak(
THREAD1_WAITING,
MAIN_THREAD_RENDEZVOUS,
Ordering::SeqCst,
Ordering::SeqCst,
) {
Ok(_) => break,
Err(FRESH) => thread::yield_now(),
Err(THREAD2_LAUNCHED) => thread::yield_now(),
Err(THREAD2_JOINED) => {
panic!("Main thread rendezvous after thread 2 joined thread 1")
}
v => unreachable!("sync state: {:?}", v),
}
Comment on lines +300 to +313
Contributor

@tmiasko, Jun 18, 2021

`compare_exchange_weak` allows spurious failures. If the current value is `THREAD1_WAITING` but the operation fails spuriously, the code falls through to the final `unreachable!` branch and panics.

Contributor

We observed this failure in #89584 on armhf Debian; I have started a retry of the build to see whether the error goes away. But yes, it should be fixed so that it never occurs.

Contributor

This was already fixed in #86383, so you must've observed a different failure.
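
For readers following the thread, here is a minimal sketch of one way a `compare_exchange_weak` loop can tolerate spurious failures. It is not taken from #86383 and may differ from the fix applied there. The constants `WAITING` and `RENDEZVOUS` and the function `rendezvous` are hypothetical names for this example. The key point is the explicit `Err(WAITING)` arm: a weak compare-exchange may fail even when the stored value equals the expected one, so that case has to be retried rather than treated as unreachable.

```rust
use std::sync::atomic::{AtomicU8, Ordering};
use std::thread;

// Hypothetical state values for this sketch.
const WAITING: u8 = 2;
const RENDEZVOUS: u8 = 3;

/// Spins until the state moves from WAITING to RENDEZVOUS and returns the
/// previous value observed by the successful exchange.
fn rendezvous(state: &AtomicU8) -> u8 {
    loop {
        match state.compare_exchange_weak(
            WAITING,
            RENDEZVOUS,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(prev) => return prev,
            // Spurious failure: the value matched but the exchange did not
            // take place. Retry immediately.
            Err(WAITING) => continue,
            // Any other state: the other side is not ready yet, so yield.
            Err(_) => thread::yield_now(),
        }
    }
}

fn main() {
    let state = AtomicU8::new(WAITING);
    assert_eq!(rendezvous(&state), WAITING);
    assert_eq!(state.load(Ordering::SeqCst), RENDEZVOUS);
}
```

An alternative with the same effect is the strong `compare_exchange`, which only fails when the stored value actually differs from the expected one.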

}
jh2.join().unwrap();
}
}