Skip to content

Commit

Permalink
Rollup merge of rust-lang#124881 - Sp00ph:reentrant_lock_tid, r=joboet
Browse files Browse the repository at this point in the history
Use ThreadId instead of TLS-address in `ReentrantLock`

Fixes rust-lang#123458

`ReentrantLock` currently uses the address of a thread local variable as an ID that's unique across all currently running threads. This can lead to uninituitive behavior as in rust-lang#123458 if TLS blocks get reused. This PR changes `ReentrantLock` to instead use the `ThreadId` provided by `std` as the unique ID. `ThreadId` guarantees uniqueness across the lifetime of the whole process, so we don't need to worry about reusing IDs of terminated threads. The main appeal of this PR is thus the possibility of changing the `ReentrantLock` API to guarantee that if a thread leaks a lock guard, no other thread may ever acquire that lock again.

This does entail some complications:
- previously, the only way to retrieve the current thread ID would've been using `thread::current().id()` which creates a temporary `Arc` and which isn't available in TLS destructors. As part of this PR, the thread ID instead gets cached in its own thread local, as suggested [here](rust-lang#123458 (comment)).
- `ThreadId` is always 64-bit whereas the current implementation uses a usize-sized ID. Since this ID needs to be updated atomically, we can't simply use a single atomic variable on 32 bit platforms. Instead, we fall back to using a (sound) seqlock on 32-bit platforms, which works because only one thread at a time can write to the ID. This seqlock is technically susceptible to the ABA problem, but the attack vector to create actual unsoundness has to be very specific:
  - You would need to be able to lock+unlock the lock exactly 2^31 times (or a multiple thereof) while a thread trying to lock it sleeps
  - The sleeping thread would have to suspend after reading one half of the thread id but before reading the other half
  - The teared result from combining the halves of the thread ID would have to exactly line up with the sleeping thread's ID

The risk of this occurring seems slim enough to be acceptable to me, but correct me if I'm wrong. This also means that the size of the lock increases by 8 bytes on 32-bit platforms, but this also shouldn't be an issue.

Performance wise, I did some crude testing of the only case where this could lead to real slowdowns, which is the case of locking a `ReentrantLock` that's already locked by the current thread. On both aarch64 and x86-64, there is (expectedly) pretty much no performance hit. I didn't have any 32-bit platforms to test the seqlock performance on, so I did the next best thing and just forced the 64-bit platforms to use the seqlock implementation. There, the performance degraded by ~1-2ns/(lock+unlock) on x86-64 and ~6-8ns/(lock+unlock) on aarch64, which is measurable but seems acceptable to me seeing as 32-bit platforms should be a small minority anyways.

cc `@joboet` `@RalfJung` `@CAD97`
  • Loading branch information
matthiaskrgr authored Jul 18, 2024
2 parents cc4ed95 + 7e21850 commit b0c85ba
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 26 deletions.
138 changes: 115 additions & 23 deletions std/src/sync/reentrant_lock.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#[cfg(all(test, not(target_os = "emscripten")))]
mod tests;

use cfg_if::cfg_if;

use crate::cell::UnsafeCell;
use crate::fmt;
use crate::ops::Deref;
use crate::panic::{RefUnwindSafe, UnwindSafe};
use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed};
use crate::sys::sync as sys;
use crate::thread::{current_id, ThreadId};

/// A re-entrant mutual exclusion lock
///
Expand Down Expand Up @@ -53,8 +55,8 @@ use crate::sys::sync as sys;
//
// The 'owner' field tracks which thread has locked the mutex.
//
// We use current_thread_unique_ptr() as the thread identifier,
// which is just the address of a thread local variable.
// We use thread::current_id() as the thread identifier, which is just the
// current thread's ThreadId, so it's unique across the process lifetime.
//
// If `owner` is set to the identifier of the current thread,
// we assume the mutex is already locked and instead of locking it again,
Expand All @@ -72,14 +74,109 @@ use crate::sys::sync as sys;
// since we're not dealing with multiple threads. If it's not equal,
// synchronization is left to the mutex, making relaxed memory ordering for
// the `owner` field fine in all cases.
//
// On systems without 64 bit atomics we also store the address of a TLS variable
// along the 64-bit TID. We then first check that address against the address
// of that variable on the current thread, and only if they compare equal do we
// compare the actual TIDs. Because we only ever read the TID on the same thread
// that it was written on (or a thread sharing the TLS block with that writer thread),
// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
// non-atomic accesses.
#[unstable(feature = "reentrant_lock", issue = "121440")]
pub struct ReentrantLock<T: ?Sized> {
mutex: sys::Mutex,
owner: AtomicUsize,
owner: Tid,
lock_count: UnsafeCell<u32>,
data: T,
}

cfg_if!(
if #[cfg(target_has_atomic = "64")] {
use crate::sync::atomic::{AtomicU64, Ordering::Relaxed};

struct Tid(AtomicU64);

impl Tid {
const fn new() -> Self {
Self(AtomicU64::new(0))
}

#[inline]
fn contains(&self, owner: ThreadId) -> bool {
owner.as_u64().get() == self.0.load(Relaxed)
}

#[inline]
// This is just unsafe to match the API of the Tid type below.
unsafe fn set(&self, tid: Option<ThreadId>) {
let value = tid.map_or(0, |tid| tid.as_u64().get());
self.0.store(value, Relaxed);
}
}
} else {
/// Returns the address of a TLS variable. This is guaranteed to
/// be unique across all currently alive threads.
fn tls_addr() -> usize {
thread_local! { static X: u8 = const { 0u8 } };

X.with(|p| <*const u8>::addr(p))
}

use crate::sync::atomic::{
AtomicUsize,
Ordering,
};

struct Tid {
// When a thread calls `set()`, this value gets updated to
// the address of a thread local on that thread. This is
// used as a first check in `contains()`; if the `tls_addr`
// doesn't match the TLS address of the current thread, then
// the ThreadId also can't match. Only if the TLS addresses do
// match do we read out the actual TID.
// Note also that we can use relaxed atomic operations here, because
// we only ever read from the tid if `tls_addr` matches the current
// TLS address. In that case, either the the tid has been set by
// the current thread, or by a thread that has terminated before
// the current thread was created. In either case, no further
// synchronization is needed (as per <https://github.com/rust-lang/miri/issues/3450>)
tls_addr: AtomicUsize,
tid: UnsafeCell<u64>,
}

unsafe impl Send for Tid {}
unsafe impl Sync for Tid {}

impl Tid {
const fn new() -> Self {
Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
}

#[inline]
// NOTE: This assumes that `owner` is the ID of the current
// thread, and may spuriously return `false` if that's not the case.
fn contains(&self, owner: ThreadId) -> bool {
// SAFETY: See the comments in the struct definition.
self.tls_addr.load(Ordering::Relaxed) == tls_addr()
&& unsafe { *self.tid.get() } == owner.as_u64().get()
}

#[inline]
// This may only be called by one thread at a time, and can lead to
// race conditions otherwise.
unsafe fn set(&self, tid: Option<ThreadId>) {
// It's important that we set `self.tls_addr` to 0 if the tid is
// cleared. Otherwise, there might be race conditions between
// `set()` and `get()`.
let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
let value = tid.map_or(0, |tid| tid.as_u64().get());
self.tls_addr.store(tls_addr, Ordering::Relaxed);
unsafe { *self.tid.get() = value };
}
}
}
);

#[unstable(feature = "reentrant_lock", issue = "121440")]
unsafe impl<T: Send + ?Sized> Send for ReentrantLock<T> {}
#[unstable(feature = "reentrant_lock", issue = "121440")]
Expand Down Expand Up @@ -134,7 +231,7 @@ impl<T> ReentrantLock<T> {
pub const fn new(t: T) -> ReentrantLock<T> {
ReentrantLock {
mutex: sys::Mutex::new(),
owner: AtomicUsize::new(0),
owner: Tid::new(),
lock_count: UnsafeCell::new(0),
data: t,
}
Expand Down Expand Up @@ -184,14 +281,16 @@ impl<T: ?Sized> ReentrantLock<T> {
/// assert_eq!(lock.lock().get(), 10);
/// ```
pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
let this_thread = current_thread_unique_ptr();
// Safety: We only touch lock_count when we own the lock.
let this_thread = current_id();
// Safety: We only touch lock_count when we own the inner mutex.
// Additionally, we only call `self.owner.set()` while holding
// the inner mutex, so no two threads can call it concurrently.
unsafe {
if self.owner.load(Relaxed) == this_thread {
if self.owner.contains(this_thread) {
self.increment_lock_count().expect("lock count overflow in reentrant mutex");
} else {
self.mutex.lock();
self.owner.store(this_thread, Relaxed);
self.owner.set(Some(this_thread));
debug_assert_eq!(*self.lock_count.get(), 0);
*self.lock_count.get() = 1;
}
Expand Down Expand Up @@ -226,14 +325,16 @@ impl<T: ?Sized> ReentrantLock<T> {
///
/// This function does not block.
pub(crate) fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
let this_thread = current_thread_unique_ptr();
// Safety: We only touch lock_count when we own the lock.
let this_thread = current_id();
// Safety: We only touch lock_count when we own the inner mutex.
// Additionally, we only call `self.owner.set()` while holding
// the inner mutex, so no two threads can call it concurrently.
unsafe {
if self.owner.load(Relaxed) == this_thread {
if self.owner.contains(this_thread) {
self.increment_lock_count()?;
Some(ReentrantLockGuard { lock: self })
} else if self.mutex.try_lock() {
self.owner.store(this_thread, Relaxed);
self.owner.set(Some(this_thread));
debug_assert_eq!(*self.lock_count.get(), 0);
*self.lock_count.get() = 1;
Some(ReentrantLockGuard { lock: self })
Expand Down Expand Up @@ -308,18 +409,9 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
unsafe {
*self.lock.lock_count.get() -= 1;
if *self.lock.lock_count.get() == 0 {
self.lock.owner.store(0, Relaxed);
self.lock.owner.set(None);
self.lock.mutex.unlock();
}
}
}
}

/// Get an address that is unique per running thread.
///
/// This can be used as a non-null usize-sized ID.
pub(crate) fn current_thread_unique_ptr() -> usize {
// Use a non-drop type to make sure it's still available during thread destruction.
thread_local! { static X: u8 = const { 0 } }
X.with(|x| <*const _>::addr(x))
}
32 changes: 29 additions & 3 deletions std/src/thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
mod tests;

use crate::any::Any;
use crate::cell::{OnceCell, UnsafeCell};
use crate::cell::{Cell, OnceCell, UnsafeCell};
use crate::env;
use crate::ffi::{CStr, CString};
use crate::fmt;
Expand Down Expand Up @@ -699,17 +699,22 @@ where
}

thread_local! {
// Invariant: `CURRENT` and `CURRENT_ID` will always be initialized together.
// If `CURRENT` is initialized, then `CURRENT_ID` will hold the same value
// as `CURRENT.id()`.
static CURRENT: OnceCell<Thread> = const { OnceCell::new() };
static CURRENT_ID: Cell<Option<ThreadId>> = const { Cell::new(None) };
}

/// Sets the thread handle for the current thread.
///
/// Aborts if the handle has been set already to reduce code size.
pub(crate) fn set_current(thread: Thread) {
let tid = thread.id();
// Using `unwrap` here can add ~3kB to the binary size. We have complete
// control over where this is called, so just abort if there is a bug.
CURRENT.with(|current| match current.set(thread) {
Ok(()) => {}
Ok(()) => CURRENT_ID.set(Some(tid)),
Err(_) => rtabort!("thread::set_current should only be called once per thread"),
});
}
Expand All @@ -719,7 +724,28 @@ pub(crate) fn set_current(thread: Thread) {
/// In contrast to the public `current` function, this will not panic if called
/// from inside a TLS destructor.
pub(crate) fn try_current() -> Option<Thread> {
CURRENT.try_with(|current| current.get_or_init(|| Thread::new_unnamed()).clone()).ok()
CURRENT
.try_with(|current| {
current
.get_or_init(|| {
let thread = Thread::new_unnamed();
CURRENT_ID.set(Some(thread.id()));
thread
})
.clone()
})
.ok()
}

/// Gets the id of the thread that invokes it.
#[inline]
pub(crate) fn current_id() -> ThreadId {
CURRENT_ID.get().unwrap_or_else(|| {
// If `CURRENT_ID` isn't initialized yet, then `CURRENT` must also not be initialized.
// `current()` will initialize both `CURRENT` and `CURRENT_ID` so subsequent calls to
// `current_id()` will succeed immediately.
current().id()
})
}

/// Gets a handle to the thread that invokes it.
Expand Down

0 comments on commit b0c85ba

Please sign in to comment.