Skip to content

Commit

Permalink
rand sampling from an array accepts an optional rng
Browse files Browse the repository at this point in the history
Fixes #5978.
  • Loading branch information
rfourquet committed Nov 19, 2014
1 parent 5e7c424 commit 0054116
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 29 deletions.
42 changes: 24 additions & 18 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ rand(T::Type, dims::Dims) = rand(GLOBAL_RNG, T, dims)
rand(T::Type, d1::Int, dims::Int...) = rand(T, tuple(d1, dims...))
rand!(A::AbstractArray) = rand!(GLOBAL_RNG, A)

rand(r::AbstractArray) = rand(GLOBAL_RNG, r)
rand!(r::Range, A::AbstractArray) = rand!(GLOBAL_RNG, r, A)

rand(r::AbstractArray, dims::Dims) = rand(GLOBAL_RNG, r, dims)
rand(r::AbstractArray, dims::Int...) = rand(GLOBAL_RNG, r, dims)

## random floating point values

@inline rand(r::AbstractRNG) = rand(r, CloseOpen)
Expand Down Expand Up @@ -344,58 +350,58 @@ end

# this function uses 32 bit entropy for small ranges of length <= typemax(UInt32) + 1
# RandIntGen is responsible for providing the right value of k
function rand{T<:Union(UInt64, Int64)}(g::RandIntGen{T,UInt64})
function rand{T<:Union(UInt64, Int64)}(mt::MersenneTwister, g::RandIntGen{T,UInt64})
local x::UInt64
if (g.k - 1) >> 32 == 0
x = rand(UInt32)
x = rand(mt, UInt32)
while x > g.u
x = rand(UInt32)
x = rand(mt, UInt32)
end
else
x = rand(UInt64)
x = rand(mt, UInt64)
while x > g.u
x = rand(UInt64)
x = rand(mt, UInt64)
end
end
return reinterpret(T, reinterpret(UInt64, g.a) + rem_knuth(x, g.k))
end

function rand{T<:Integer, U<:Unsigned}(g::RandIntGen{T,U})
x = rand(U)
function rand{T<:Integer, U<:Unsigned}(mt::MersenneTwister, g::RandIntGen{T,U})
x = rand(mt, U)
while x > g.u
x = rand(U)
x = rand(mt, U)
end
(unsigned(g.a) + rem_knuth(x, g.k)) % T
end

rand{T<:Union(Signed,Unsigned,Bool,Char)}(r::UnitRange{T}) = rand(RandIntGen(r))
rand{T<:Union(Signed,Unsigned,Bool,Char)}(mt::MersenneTwister, r::UnitRange{T}) = rand(mt, RandIntGen(r))

# Randomly draw a sample from an AbstractArray r
# (e.g. r is a range 0:2:8 or a vector [2, 3, 5, 7])
rand(r::AbstractArray) = @inbounds return r[rand(1:length(r))]
rand(mt::MersenneTwister, r::AbstractArray) = @inbounds return r[rand(mt, 1:length(r))]

function rand!(g::RandIntGen, A::AbstractArray)
function rand!(mt::MersenneTwister, g::RandIntGen, A::AbstractArray)
for i = 1 : length(A)
@inbounds A[i] = rand(g)
@inbounds A[i] = rand(mt, g)
end
return A
end

rand!{T<:Union(Signed,Unsigned,Bool,Char)}(r::UnitRange{T}, A::AbstractArray) = rand!(RandIntGen(r), A)
rand!{T<:Union(Signed,Unsigned,Bool,Char)}(mt::MersenneTwister, r::UnitRange{T}, A::AbstractArray) = rand!(mt, RandIntGen(r), A)

rand!(r::Range, A::AbstractArray) = _rand!(r, A)
rand!(mt::MersenneTwister, r::Range, A::AbstractArray) = _rand!(mt, r, A)

# TODO: this more general version is "disabled" until #8246 is resolved
function _rand!(r::AbstractArray, A::AbstractArray)
function _rand!(mt::MersenneTwister, r::AbstractArray, A::AbstractArray)
g = RandIntGen(1:(length(r)))
for i = 1 : length(A)
@inbounds A[i] = r[rand(g)]
@inbounds A[i] = r[rand(mt, g)]
end
return A
end

rand{T}(r::AbstractArray{T}, dims::Dims) = _rand!(r, Array(T, dims))
rand(r::AbstractArray, dims::Int...) = rand(r, dims)
rand{T}(mt::MersenneTwister, r::AbstractArray{T}, dims::Dims) = _rand!(mt, r, Array(T, dims))
rand(mt::MersenneTwister, r::AbstractArray, dims::Int...) = rand(mt, r, dims)

## random BitArrays (AbstractRNG)

Expand Down
16 changes: 9 additions & 7 deletions doc/stdlib/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4071,19 +4071,21 @@ A ``MersenneTwister`` RNG can generate random numbers of the following types: ``

Create a ``MersenneTwister`` RNG object. Different RNG objects can have their own seeds, which may be useful for generating different streams of random numbers.

.. function:: rand([rng], [t::Type], [dims...])
.. function:: rand([rng], [S], [dims...])

Generate a random value or an array of random values of the given type, ``t``, which defaults to ``Float64``.
Pick a random element or array of random elements from the set of values specified by ``S``; ``S`` can be

.. function:: rand!([rng], A)
* an indexable collection (for example ``1:n`` or ``['x','y','z']``), or

Populate the array A with random values.
* a type: the set of values to pick from is then equivalent to ``typemin(S):typemax(S)`` for integers, and to [0,1) for floating point numbers;

.. function:: rand(coll, [dims...])
``S`` defaults to ``Float64``.

Pick a random element or array of random elements from the indexable collection ``coll`` (for example, ``1:n`` or ``['x','y','z']``).
.. function:: rand!([rng], A)

Populate the array A with random values.

.. function:: rand!(r, A)
.. function:: rand!([rng], r, A)

Populate the array A with random values drawn uniformly from the range ``r``.

Expand Down
18 changes: 14 additions & 4 deletions test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,20 @@ rand!(MersenneTwister(0), A)
8690327730555225005 8435109092665372532]

# rand from AbstractArray
@test rand(0:3:1000) in 0:3:1000
coll = Any[2, UInt128(128), big(619), "string", 'c']
@test rand(coll) in coll
@test issubset(rand(coll, 2, 3), coll)
let mt = MersenneTwister()
srand(mt)
@test rand(mt, 0:3:1000) in 0:3:1000
@test issubset(rand!(mt, 0:3:1000, Array(Int, 100)), 0:3:1000)
coll = Any[2, UInt128(128), big(619), "string", 'c']
@test rand(mt, coll) in coll
@test issubset(rand(mt, coll, 2, 3), coll)

# check API with default RNG:
rand(0:3:1000)
rand!(0:3:1000, Array(Int, 100))
rand(coll)
rand(coll, 2, 3)
end

# randn
@test randn(MersenneTwister(42)) == -0.5560268761463861
Expand Down

0 comments on commit 0054116

Please sign in to comment.