diff --git a/src/Stats.jl b/src/Stats.jl index 72eede5567cf9..700505e6ee6a0 100644 --- a/src/Stats.jl +++ b/src/Stats.jl @@ -28,7 +28,9 @@ module Stats tiedrank, weighted_mean, randshuffle!, - randsample + randsample, + sample_by_weights + # Weighted mean # NB: Weights should include 1/n factor @@ -639,4 +641,35 @@ module Stats randsample{T}(x::AbstractVector{T}, n::Integer) = x[randsample(1:length(x), n)] + + ########################################################### + # + # A fast sampling method to sample x ~ p, + # where p is proportional to given weights. + # + # This algorithm performs the sampling without + # computing p (normalizing the weights). + # + # This function is particularly useful in many MCMC + # algorithms. + # + ########################################################### + + function sample_by_weights(w::Vector{Float64}, totalw::Float64) + n = length(w) + t = rand() * totalw + + x = 1 + s = w[1] + + while x < n && s < t + x += 1 + s += w[x] + end + return x + end + + sample_by_weights(w::Vector{Float64}) = sample_by_weights(w, sum(w)) + end # module + diff --git a/test/rands.jl b/test/rands.jl index 1850cf56f8cab..915652e8f79f0 100644 --- a/test/rands.jl +++ b/test/rands.jl @@ -24,7 +24,6 @@ function verify_randsample(r::Matrix{Int}, m::Int, tol::Float64) end - # rand shuffle m = 5 @@ -67,3 +66,17 @@ for i = 1 : n end @test verify_randsample(r, m, 0.02) +# sample_by_weights + +w = [1., 2., 3., 4.] +n = 1000000 + +cnts = zeros(Int, 4) +for i = 1 : n + xi = sample_by_weights(w, 10.) + cnts[xi] += 1 +end +p = cnts / n +p0 = w / sum(w) +@test all(abs(p - p0) .< 0.01) +