Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement faster thread local rng for scheduler #55501

Merged
merged 9 commits into from
Sep 9, 2024
51 changes: 50 additions & 1 deletion base/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,56 @@ const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
const heaps_lock = [SpinLock(), SpinLock()]


cong(max::UInt32) = iszero(max) ? UInt32(0) : ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1)
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
"""
cong(max::UInt32)
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
"""
cong(max::UInt32) = iszero(max) ? UInt32(0) : jl_rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved

const rngseed_offset = unsafe_load(cglobal(:jl_ptls_rng_offset, Cint))

"""
jl_rand_ptls(max::UInt32)
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
state. Max must be greater than 0.
"""
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function jl_rand_ptls(max::UInt32)
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
ptls = Base.unsafe_convert(Ptr{UInt64}, Core.getptls())
rngseed = Base.unsafe_load(ptls + rngseed_offset)
val, seed = rand_uniform_max_int32(max, rngseed)
Base.unsafe_store!(ptls + rngseed_offset, seed)
return val % UInt32
end

# This implementation is based on OpenSSLs implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implementation as well.
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
rand_uniform_max_int32(max::UInt32, seed::UInt64)
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
Return a random UInt32 in the range `0:max-1` using the given seed.
Max must be greater than 0.
"""
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
if max == UInt32(1)
return UInt32(0), seed
end
# We are generating a fixed point number on the interval [0, 1).
# Multiplying this by the range gives us a number on [0, upper).
# The high word of the multiplication result represents the integral part
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
seed = UInt64(69069) * seed + UInt64(362437)
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
i = prod >> 32 % UInt32 # integral part
return i % UInt32, seed
end



function multiq_sift_up(heap::taskheap, idx::Int32)
Expand Down
1 change: 1 addition & 0 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel)
// Important offset for external codegen.
jl_task_gcstack_offset = offsetof(jl_task_t, gcstack);
jl_task_ptls_offset = offsetof(jl_task_t, ptls);
jl_ptls_rng_offset = offsetof(jl_tls_states_t, rngseed);

jl_prep_sanitizers();
void *stack_lo, *stack_hi;
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,6 @@
XX(jl_options, jl_options_t) \
XX(jl_task_gcstack_offset, int) \
XX(jl_task_ptls_offset, int) \
XX(jl_ptls_rng_offset, int) \

// end of file
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -2256,6 +2256,7 @@ JL_DLLEXPORT JL_CONST_FUNC jl_gcframe_t **(jl_get_pgcstack)(void) JL_GLOBALLY_RO

extern JL_DLLIMPORT int jl_task_gcstack_offset;
extern JL_DLLIMPORT int jl_task_ptls_offset;
extern JL_DLLIMPORT int jl_ptls_rng_offset;

#include "julia_locks.h" // requires jl_task_t definition

Expand Down