Skip to content

Commit

Permalink
std: Add a variant of thread locals with const init
Browse files Browse the repository at this point in the history
This commit adds a variant of the `thread_local!` macro as a new
`thread_local_const_init!` macro which requires that the initialization
expression is constant (e.g. could be stuck into a `const` if so
desired). This form of thread local allows for a more efficient
implementation of `LocalKey::with` both if the value has a destructor
and if it doesn't. If the value doesn't have a destructor then `with`
should desugar to exactly as-if you use `#[thread_local]` given
sufficient inlining.

The purpose of this new form of thread locals is to precisely be
equivalent to `#[thread_local]` on platforms where possible for values
which fit the bill (those without destructors). This should help close
the gap in performance between `thread_local!`, which is safe, relative
to `#[thread_local]`, which is not easy to use in a portable fashion.
  • Loading branch information
alexcrichton committed Apr 16, 2021
1 parent f1ca558 commit c6eea22
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 49 deletions.
5 changes: 4 additions & 1 deletion library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@
// std may use features in a platform-specific way
#![allow(unused_features)]
#![feature(rustc_allow_const_fn_unstable)]
#![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count))]
#![cfg_attr(
test,
feature(internal_output_capture, print_internals, update_panic_count, thread_local_const_init)
)]
#![cfg_attr(
all(target_vendor = "fortanix", target_env = "sgx"),
feature(slice_index_methods, coerce_unsized, sgx_platform)
Expand Down
117 changes: 115 additions & 2 deletions library/std/src/thread/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ macro_rules! thread_local {
// empty (base case for the recursion)
() => {};

($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = const { $init:expr }; $($rest:tt)*) => (
$crate::__thread_local_inner!($(#[$attr])* $vis $name, $t, const $init);
$crate::thread_local!($($rest)*);
);

($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = const { $init:expr }) => (
$crate::__thread_local_inner!($(#[$attr])* $vis $name, $t, const $init);
);

// process multiple declarations
($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = $init:expr; $($rest:tt)*) => (
$crate::__thread_local_inner!($(#[$attr])* $vis $name, $t, $init);
Expand All @@ -151,6 +160,101 @@ macro_rules! thread_local {
#[allow_internal_unstable(thread_local_internals, cfg_target_thread_local, thread_local)]
#[allow_internal_unsafe]
macro_rules! __thread_local_inner {
// used to generate the `LocalKey` value for const-initialized thread locals
(@key $t:ty, const $init:expr) => {{
unsafe fn __getit() -> $crate::option::Option<&'static $t> {
const _REQUIRE_UNSTABLE: () = $crate::thread::require_unstable_const_init_thread_local();

// wasm without atomics maps directly to `static mut`, and dtors
// aren't implemented because thread dtors aren't really a thing
// on wasm right now
//
// FIXME(#84224) this should come after the `target_thread_local`
// block.
#[cfg(all(target_arch = "wasm32", not(target_feature = "atomics")))]
{
static mut VAL: $t = $init;
Some(&VAL)
}

// If the platform has support for `#[thread_local]`, use it.
#[cfg(all(
target_thread_local,
not(all(target_arch = "wasm32", not(target_feature = "atomics"))),
))]
{
// If a dtor isn't needed we can do something "very raw" and
// just get going.
if !$crate::mem::needs_drop::<$t>() {
#[thread_local]
static mut VAL: $t = $init;
unsafe {
return Some(&VAL)
}
}

#[thread_local]
static mut VAL: $t = $init;
// 0 == dtor not registered
// 1 == dtor registered, dtor not run
// 2 == dtor registered and is running or has run
#[thread_local]
static mut STATE: u8 = 0;

unsafe extern "C" fn destroy(ptr: *mut u8) {
let ptr = ptr as *mut $t;

unsafe {
debug_assert_eq!(STATE, 1);
STATE = 2;
$crate::ptr::drop_in_place(ptr);
}
}

unsafe {
match STATE {
// 0 == we haven't registered a destructor, so do
// so now.
0 => {
$crate::thread::__FastLocalKeyInner::<$t>::register_dtor(
&VAL as *const _ as *mut u8,
destroy,
);
STATE = 1;
Some(&VAL)
}
// 1 == the destructor is registered and the value
// is valid, so return the pointer.
1 => Some(&VAL),
// otherwise the destructor has already run, so we
// can't give access.
_ => None,
}
}
}

// On platforms without `#[thread_local]` we fall back to the
// same implementation as below for os thread locals.
#[cfg(all(
not(target_thread_local),
not(all(target_arch = "wasm32", not(target_feature = "atomics"))),
))]
{
#[inline]
const fn __init() -> $t { $init }
static __KEY: $crate::thread::__OsLocalKeyInner<$t> =
$crate::thread::__OsLocalKeyInner::new();
#[allow(unused_unsafe)]
unsafe { __KEY.get(__init) }
}
}

unsafe {
$crate::thread::LocalKey::new(__getit)
}
}};

// used to generate the `LocalKey` value for `thread_local!`
(@key $t:ty, $init:expr) => {
{
#[inline]
Expand Down Expand Up @@ -188,9 +292,9 @@ macro_rules! __thread_local_inner {
}
}
};
($(#[$attr:meta])* $vis:vis $name:ident, $t:ty, $init:expr) => {
($(#[$attr:meta])* $vis:vis $name:ident, $t:ty, $($init:tt)*) => {
$(#[$attr])* $vis const $name: $crate::thread::LocalKey<$t> =
$crate::__thread_local_inner!(@key $t, $init);
$crate::__thread_local_inner!(@key $t, $($init)*);
}
}

Expand Down Expand Up @@ -442,6 +546,15 @@ pub mod fast {
Key { inner: LazyKeyInner::new(), dtor_state: Cell::new(DtorState::Unregistered) }
}

// note that this is just a publically-callable function only for the
// const-initialized form of thread locals, basically a way to call the
// free `register_dtor` function defined elsewhere in libstd.
pub unsafe fn register_dtor(a: *mut u8, dtor: unsafe extern "C" fn(*mut u8)) {
unsafe {
register_dtor(a, dtor);
}
}

pub unsafe fn get<F: FnOnce() -> T>(&self, init: F) -> Option<&'static T> {
// SAFETY: See the definitions of `LazyKeyInner::get` and
// `try_initialize` for more informations.
Expand Down
147 changes: 101 additions & 46 deletions library/std/src/thread/local/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::cell::{Cell, UnsafeCell};
use crate::sync::mpsc::{channel, Sender};
use crate::thread;
use crate::thread::{self, LocalKey};
use crate::thread_local;

struct Foo(Sender<()>);
Expand All @@ -15,74 +15,90 @@ impl Drop for Foo {
#[test]
fn smoke_no_dtor() {
thread_local!(static FOO: Cell<i32> = Cell::new(1));
run(&FOO);
thread_local!(static FOO2: Cell<i32> = const { Cell::new(1) });
run(&FOO2);

FOO.with(|f| {
assert_eq!(f.get(), 1);
f.set(2);
});
let (tx, rx) = channel();
let _t = thread::spawn(move || {
FOO.with(|f| {
fn run(key: &'static LocalKey<Cell<i32>>) {
key.with(|f| {
assert_eq!(f.get(), 1);
f.set(2);
});
tx.send(()).unwrap();
});
rx.recv().unwrap();
let t = thread::spawn(move || {
key.with(|f| {
assert_eq!(f.get(), 1);
});
});
t.join().unwrap();

FOO.with(|f| {
assert_eq!(f.get(), 2);
});
key.with(|f| {
assert_eq!(f.get(), 2);
});
}
}

#[test]
fn states() {
struct Foo;
struct Foo(&'static LocalKey<Foo>);
impl Drop for Foo {
fn drop(&mut self) {
assert!(FOO.try_with(|_| ()).is_err());
assert!(self.0.try_with(|_| ()).is_err());
}
}
thread_local!(static FOO: Foo = Foo);

thread::spawn(|| {
assert!(FOO.try_with(|_| ()).is_ok());
})
.join()
.ok()
.expect("thread panicked");
thread_local!(static FOO: Foo = Foo(&FOO));
run(&FOO);
thread_local!(static FOO2: Foo = const { Foo(&FOO2) });
run(&FOO2);

fn run(foo: &'static LocalKey<Foo>) {
thread::spawn(move || {
assert!(foo.try_with(|_| ()).is_ok());
})
.join()
.unwrap();
}
}

#[test]
fn smoke_dtor() {
thread_local!(static FOO: UnsafeCell<Option<Foo>> = UnsafeCell::new(None));

let (tx, rx) = channel();
let _t = thread::spawn(move || unsafe {
let mut tx = Some(tx);
FOO.with(|f| {
*f.get() = Some(Foo(tx.take().unwrap()));
run(&FOO);
thread_local!(static FOO2: UnsafeCell<Option<Foo>> = const { UnsafeCell::new(None) });
run(&FOO2);

fn run(key: &'static LocalKey<UnsafeCell<Option<Foo>>>) {
let (tx, rx) = channel();
let t = thread::spawn(move || unsafe {
let mut tx = Some(tx);
key.with(|f| {
*f.get() = Some(Foo(tx.take().unwrap()));
});
});
});
rx.recv().unwrap();
rx.recv().unwrap();
t.join().unwrap();
}
}

#[test]
fn circular() {
struct S1;
struct S2;
struct S1(&'static LocalKey<UnsafeCell<Option<S1>>>, &'static LocalKey<UnsafeCell<Option<S2>>>);
struct S2(&'static LocalKey<UnsafeCell<Option<S1>>>, &'static LocalKey<UnsafeCell<Option<S2>>>);
thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
thread_local!(static K2: UnsafeCell<Option<S2>> = UnsafeCell::new(None));
static mut HITS: u32 = 0;
thread_local!(static K3: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });
thread_local!(static K4: UnsafeCell<Option<S2>> = const { UnsafeCell::new(None) });
static mut HITS: usize = 0;

impl Drop for S1 {
fn drop(&mut self) {
unsafe {
HITS += 1;
if K2.try_with(|_| ()).is_err() {
if self.1.try_with(|_| ()).is_err() {
assert_eq!(HITS, 3);
} else {
if HITS == 1 {
K2.with(|s| *s.get() = Some(S2));
self.1.with(|s| *s.get() = Some(S2(self.0, self.1)));
} else {
assert_eq!(HITS, 3);
}
Expand All @@ -94,38 +110,54 @@ fn circular() {
fn drop(&mut self) {
unsafe {
HITS += 1;
assert!(K1.try_with(|_| ()).is_ok());
assert!(self.0.try_with(|_| ()).is_ok());
assert_eq!(HITS, 2);
K1.with(|s| *s.get() = Some(S1));
self.0.with(|s| *s.get() = Some(S1(self.0, self.1)));
}
}
}

thread::spawn(move || {
drop(S1);
drop(S1(&K1, &K2));
})
.join()
.unwrap();

unsafe {
HITS = 0;
}

thread::spawn(move || {
drop(S1(&K3, &K4));
})
.join()
.ok()
.expect("thread panicked");
.unwrap();
}

#[test]
fn self_referential() {
struct S1;
struct S1(&'static LocalKey<UnsafeCell<Option<S1>>>);

thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
thread_local!(static K2: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });

impl Drop for S1 {
fn drop(&mut self) {
assert!(K1.try_with(|_| ()).is_err());
assert!(self.0.try_with(|_| ()).is_err());
}
}

thread::spawn(move || unsafe {
K1.with(|s| *s.get() = Some(S1));
K1.with(|s| *s.get() = Some(S1(&K1)));
})
.join()
.ok()
.expect("thread panicked");
.unwrap();

thread::spawn(move || unsafe {
K2.with(|s| *s.get() = Some(S1(&K2)));
})
.join()
.unwrap();
}

// Note that this test will deadlock if TLS destructors aren't run (this
Expand All @@ -152,3 +184,26 @@ fn dtors_in_dtors_in_dtors() {
});
rx.recv().unwrap();
}

#[test]
fn dtors_in_dtors_in_dtors_const_init() {
struct S1(Sender<()>);
thread_local!(static K1: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });
thread_local!(static K2: UnsafeCell<Option<Foo>> = const { UnsafeCell::new(None) });

impl Drop for S1 {
fn drop(&mut self) {
let S1(ref tx) = *self;
unsafe {
let _ = K2.try_with(|s| *s.get() = Some(Foo(tx.clone())));
}
}
}

let (tx, rx) = channel();
let _t = thread::spawn(move || unsafe {
let mut tx = Some(tx);
K1.with(|s| *s.get() = Some(S1(tx.take().unwrap())));
});
rx.recv().unwrap();
}
7 changes: 7 additions & 0 deletions library/std/src/thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ pub use self::local::os::Key as __OsLocalKeyInner;
#[doc(hidden)]
pub use self::local::statik::Key as __StaticLocalKeyInner;

// This is only used to make thread locals with `const { .. }` initialization
// expressions unstable. If and/or when that syntax is stabilized with thread
// locals this will simply be removed.
#[doc(hidden)]
#[unstable(feature = "thread_local_const_init", issue = "84223")]
pub const fn require_unstable_const_init_thread_local() {}

////////////////////////////////////////////////////////////////////////////////
// Builder
////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit c6eea22

Please sign in to comment.