Skip to content

Commit

Permalink
Implement faster thread local rng for scheduler (#55501)
Browse files Browse the repository at this point in the history
Implement optimal uniform random number generator using the method
proposed in swiftlang/swift#39143 based on
OpenSSL's implementation of it in
https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99

This PR also fixes some bugs found while developing it. This is a
replacement for #50203 and fixes
the issues found by @IanButterworth with both rngs

C rng
<img width="1011" alt="image"
src="https://github.com/user-attachments/assets/0dd9d5f2-17ef-4a70-b275-1d12692be060">

New scheduler rng
<img width="985" alt="image"
src="https://github.com/user-attachments/assets/4abd0a57-a1d9-46ec-99a5-535f366ecafa">

~On my benchmarks the julia implementation seems to be almost 50% faster
than the current implementation.~
With oscars suggestion of removing the debiasing this is now almost 5x
faster than the original implementation. And almost fully branchless

We might want to backport the two previous commits since they
technically fix bugs.

---------

Co-authored-by: Valentin Churavy <[email protected]>
  • Loading branch information
gbaraldi and vchuravy authored Sep 9, 2024
1 parent 5272dad commit 169e9e8
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 10 deletions.
55 changes: 54 additions & 1 deletion base/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,60 @@ const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
const heaps_lock = [SpinLock(), SpinLock()]


cong(max::UInt32) = iszero(max) ? UInt32(0) : ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1)
"""
cong(max::UInt32)
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
"""
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check

get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())

set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)

"""
rand_ptls(max::UInt32)
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
state. Max must be greater than 0.
"""
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
rngseed = get_ptls_rng()
val, seed = rand_uniform_max_int32(max, rngseed)
set_ptls_rng(seed)
return val % UInt32
end

# This implementation is based on OpenSSLs implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implementation as well.
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
rand_uniform_max_int32(max::UInt32, seed::UInt64)
Return a random UInt32 in the range `0:max-1` using the given seed.
Max must be greater than 0.
"""
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
if max == UInt32(1)
return UInt32(0), seed
end
# We are generating a fixed point number on the interval [0, 1).
# Multiplying this by the range gives us a number on [0, upper).
# The high word of the multiplication result represents the integral part
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
seed = UInt64(69069) * seed + UInt64(362437)
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
i = prod >> 32 % UInt32 # integral part
return i % UInt32, seed
end



function multiq_sift_up(heap::taskheap, idx::Int32)
Expand Down
32 changes: 32 additions & 0 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ TRANSFORMED_CCALL_STAT(jl_cpu_wake);
TRANSFORMED_CCALL_STAT(jl_gc_safepoint);
TRANSFORMED_CCALL_STAT(jl_get_ptls_states);
TRANSFORMED_CCALL_STAT(jl_threadid);
TRANSFORMED_CCALL_STAT(jl_get_ptls_rng);
TRANSFORMED_CCALL_STAT(jl_set_ptls_rng);
TRANSFORMED_CCALL_STAT(jl_get_tls_world_age);
TRANSFORMED_CCALL_STAT(jl_get_world_counter);
TRANSFORMED_CCALL_STAT(jl_gc_enable_disable_finalizers_internal);
Expand Down Expand Up @@ -1692,6 +1694,36 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
ai.decorateInst(tid);
return mark_or_box_ccall_result(ctx, tid, retboxed, rt, unionall, static_rt);
}
else if (is_libjulia_func(jl_get_ptls_rng)) {
++CCALL_STAT(jl_get_ptls_rng);
assert(lrt == getInt64Ty(ctx.builder.getContext()));
assert(!isVa && !llvmcall && nccallargs == 0);
JL_GC_POP();
Value *ptls_p = get_current_ptls(ctx);
const int rng_offset = offsetof(jl_tls_states_t, rngseed);
Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t)));
setName(ctx.emission_context, rng_ptr, "rngseed_ptr");
LoadInst *rng_value = ctx.builder.CreateAlignedLoad(getInt64Ty(ctx.builder.getContext()), rng_ptr, Align(sizeof(void*)));
setName(ctx.emission_context, rng_value, "rngseed");
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
ai.decorateInst(rng_value);
return mark_or_box_ccall_result(ctx, rng_value, retboxed, rt, unionall, static_rt);
}
else if (is_libjulia_func(jl_set_ptls_rng)) {
++CCALL_STAT(jl_set_ptls_rng);
assert(lrt == getVoidTy(ctx.builder.getContext()));
assert(!isVa && !llvmcall && nccallargs == 1);
JL_GC_POP();
Value *ptls_p = get_current_ptls(ctx);
const int rng_offset = offsetof(jl_tls_states_t, rngseed);
Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t)));
setName(ctx.emission_context, rng_ptr, "rngseed_ptr");
assert(argv[0].V->getType() == getInt64Ty(ctx.builder.getContext()));
auto store = ctx.builder.CreateAlignedStore(argv[0].V, rng_ptr, Align(sizeof(void*)));
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
ai.decorateInst(store);
return ghostValue(ctx, jl_nothing_type);
}
else if (is_libjulia_func(jl_get_tls_world_age)) {
bool toplevel = !(ctx.linfo && jl_is_method(ctx.linfo->def.method));
if (!toplevel) { // top level code does not see a stable world age during execution
Expand Down
2 changes: 2 additions & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,8 @@
XX(jl_test_cpu_feature) \
XX(jl_threadid) \
XX(jl_threadpoolid) \
XX(jl_get_ptls_rng) \
XX(jl_set_ptls_rng) \
XX(jl_throw) \
XX(jl_throw_out_of_memory_error) \
XX(jl_too_few_args) \
Expand Down
2 changes: 2 additions & 0 deletions src/julia_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ extern "C" {

JL_DLLEXPORT int16_t jl_threadid(void);
JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT;
JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT;

// JULIA_ENABLE_THREADING may be controlled by altering JULIA_THREADS in Make.user

Expand Down
9 changes: 0 additions & 9 deletions src/scheduler.c
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,6 @@ JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSA
extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache,
jl_gc_markqueue_t *mq, jl_value_t *obj) JL_NOTSAFEPOINT;

// parallel task runtime
// ---

JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max) // [0, n)
{
jl_ptls_t ptls = jl_current_task->ptls;
return cong(max, &ptls->rngseed);
}

// initialize the threading infrastructure
// (called only by the main thread)
void jl_init_threadinginfra(void)
Expand Down
12 changes: 12 additions & 0 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,18 @@ JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT
return -1; // everything else uses threadpool -1 (does not belong to any threadpool)
}

// get thread local rng
JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT
{
return jl_current_task->ptls->rngseed;
}

// get thread local rng
JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT
{
jl_current_task->ptls->rngseed = new_seed;
}

jl_ptls_t jl_init_threadtls(int16_t tid)
{
#ifndef _OS_WINDOWS_
Expand Down

0 comments on commit 169e9e8

Please sign in to comment.