Skip to content

Commit

Permalink
Implement optimal uniform random number generator using the method pr…
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaraldi committed Aug 15, 2024
1 parent d84d3ad commit 7223ddf
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
88 changes: 87 additions & 1 deletion base/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,95 @@ const heap_d = UInt32(8)
const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
const heaps_lock = [SpinLock(), SpinLock()]

"""
cong(max::UInt32)
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
"""
cong(max::UInt32) = iszero(max) ? UInt32(0) : jl_rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check


"""
jl_rand_ptls(max::UInt32)
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
state. Max must be greater than 0.
"""
function jl_rand_ptls(max::UInt32)
ptls = Base.unsafe_convert(Ptr{UInt64}, Core.getptls())
rngseed = Base.unsafe_load(ptls, 2)
val, seed = rand_uniform_max_int32(max, rngseed)
Base.unsafe_store!(ptls, seed, 2)
return val % UInt32
end

# This implementation is based on OpenSSLs implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implemantation as well.

Check warning on line 46 in base/partr.jl

View workflow job for this annotation

GitHub Actions / Check for new typos

perhaps "implemantation" should be "implementation".
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
rand_uniform_max_int32(max::UInt32, seed::UInt64)
Return a random UInt32 in the range `0:max-1` using the given seed.
Max must be greater than 0.
"""
function rand_uniform_max_int32(max::UInt32, seed::UInt64)
if max == UInt32(1)
return UInt32(0), seed
end

cong(max::UInt32) = iszero(max) ? UInt32(0) : ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1)
# We are generating a fixed point number on the interval [0, 1).
# Multiplying this by the range gives us a number on [0, upper).
# The high word of the multiplication result represents the integral
# part we want. The lower word is the fractional part. We can early exit if
# if the fractional part is small enough that no carry from the next lower
# word can cause an overflow and carry into the integer part. This
# happens when the fractional part is bounded by 2^32 - upper which
# can be simplified to just -upper (as an unsigned integer).
seed = UInt64(69069) * seed + UInt64(362437)
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
i = unsafe_trunc(UInt32, prod >> 32) # integral part
f = unsafe_trunc(UInt32, (prod & 0xffffffff)) # fractional part
if (f <= (UInt32(1) + ~max)) # likely
return unsafe_trunc(UInt32, i), seed
end

# We're in the position where the carry from the next word *might* cause
# a carry to the integral part. The process here is to generate the next
# word, multiply it by the range and add that to the current word. If
# it overflows, the carry propagates to the integer part (return i+1).
# If it can no longer overflow regardless of further lower order bits,
# we are done (return i). If there is still a chance of overflow, we
# repeat the process with the next lower word.
#
# Each *bit* of randomness has a probability of one half of terminating
# this process, so each each word beyond the first has a probability
# of 2^-32 of not terminating the process. That is, we're extremely
# likely to stop very rapidly.
for _ in 1:10
seed = UInt64(69069) * seed + UInt64(362437)
prod = (UInt64(max)) * (seed % UInt32)
f2 = unsafe_trunc(UInt32,prod >> 32) # extra fractional part
f *= f2 % UInt32
if f < f2
return i + UInt32(1), seed
end
if (f != 0xffffffff) #unlikely
return i, seed
end
f = prod & 0xffffffff % UInt32
end
# If we get here, we've consumed 32 * max_followup_iterations + 32 bits
# with no firm decision, this gives a bias with probability < 2^-(32*n),
# which is likely acceptable.
return i, seed
end

function multiq_sift_up(heap::taskheap, idx::Int32)
while idx > Int32(1)
Expand Down
3 changes: 3 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,9 @@ JL_DLLEXPORT size_t jl_maxrss(void);
// congruential random number generator
// for a small amount of thread-local randomness

//TODO: utilize https://github.com/openssl/openssl/blob/master/crypto/rand/rand_uniform.c#L13-L99
// for better performance, it does however require making users expect a 32bit random number.

STATIC_INLINE uint64_t cong(uint64_t max, uint64_t *seed) JL_NOTSAFEPOINT
{
if (max < 2)
Expand Down

0 comments on commit 7223ddf

Please sign in to comment.