Skip to content

Commit

Permalink
more type-stable reductions (fix JuliaLang#6069)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed Mar 11, 2014
1 parent 764098f commit d114e11
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 47 deletions.
141 changes: 94 additions & 47 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

###### higher level reduction functions ######

# Note that getting type-stable results from reduction functions,
# or at least having type-stable loops, is nontrivial (#6069).

# reduce

function reduce(op::Callable, itr) # this is a left fold
Expand All @@ -19,17 +22,24 @@ function reduce(op::Callable, itr) # this is a left fold
return op() # empty collection
end
(v, s) = next(itr, s)
while !done(itr, s)
if done(itr, s)
return v
else # specialize for length > 1 to have a hopefully type-stable loop
(x, s) = next(itr, s)
v = op(v,x)
result = op(v, x)
while !done(itr, s)
(x, s) = next(itr, s)
result = op(result, x)
end
return result
end
return v
end

# pairwise reduction, requires n > 1 (to allow type-stable loop)
function r_pairwise(op::Callable, A::AbstractArray, i1,n)
if n < 128
@inbounds v = A[i1]
for i = i1+1:i1+n-1
@inbounds v = op(A[i1], A[i1+1])
for i = i1+2:i1+n-1
@inbounds v = op(v,A[i])
end
return v
Expand All @@ -41,12 +51,12 @@ end

function reduce(op::Callable, A::AbstractArray)
n = length(A)
n == 0 ? op() : r_pairwise(op,A, 1,n)
n == 0 ? op() : n == 1 ? A[1] : r_pairwise(op,A, 1,n)
end

function reduce(op::Callable, v0, A::AbstractArray)
n = length(A)
n == 0 ? v0 : op(v0, r_pairwise(op,A, 1,n))
n == 0 ? v0 : n == 1 ? op(v0, A[1]) : op(v0, r_pairwise(op,A, 1,n))
end

function reduce(op::Callable, v0, itr)
Expand Down Expand Up @@ -169,46 +179,53 @@ end
function sum(itr)
s = start(itr)
if done(itr, s)
return 0
if method_exists(eltype, (typeof(itr),))
T = eltype(itr)
return zero(T) + zero(T)
else
throw(ArgumentError("sum(itr) is undefined for empty collections; instead, do isempty(itr) ? z : sum(itr), where z is the correct type of zero for your sum"))
end
end
(v, s) = next(itr, s)
done(itr, s) && return v + zero(v) # adding zero for type stability
# specialize for length > 1 to have type-stable loop
(x, s) = next(itr, s)
result = v + x
while !done(itr, s)
(x, s) = next(itr, s)
v += x
result += x
end
return v
return result
end

sum(A::AbstractArray{Bool}) = countnz(A)

# a fast implementation of sum in sequential order (from left to right)
# a fast implementation of sum in sequential order (from left to right).
# to allow type-stable loops, requires length > 1
function sum_seq{T}(a::AbstractArray{T}, ifirst::Int, ilast::Int)

@inbounds if ifirst + 3 >= ilast # a has at most four elements
if ifirst > ilast
return zero(T)
else
i = ifirst
s = a[i]
while i < ilast
s += a[i+=1]
end
return s

@inbounds if ifirst + 6 >= ilast # length(a) < 8
i = ifirst
s = a[i] + a[i+1]
i = i+1
while i < ilast
s += a[i+=1]
end
return s

else # a has more than four elements
else # length(a) >= 8

# more effective utilization of the instruction
# pipeline through manually unrolling the sum
# into four-way accumulation. Benchmark shows
# that this results in about 2x speed-up.

s1 = a[ifirst]
s2 = a[ifirst + 1]
s3 = a[ifirst + 2]
s4 = a[ifirst + 3]
s1 = a[ifirst] + a[ifirst + 4]
s2 = a[ifirst + 1] + a[ifirst + 5]
s3 = a[ifirst + 2] + a[ifirst + 6]
s4 = a[ifirst + 3] + a[ifirst + 7]

i = ifirst + 4
i = ifirst + 8
il = ilast - 3
while i <= il
s1 += a[i]
Expand Down Expand Up @@ -243,6 +260,7 @@ end
# Note: sum_seq uses four accumulators, so each accumulator gets at most 256 numbers
const PAIRWISE_SUM_BLOCKSIZE = 1024

# note: requires length > 1, due to sum_seq
function sum_pairwise(a::AbstractArray, ifirst::Int, ilast::Int)
# bsiz: maximum block size

Expand All @@ -254,18 +272,29 @@ function sum_pairwise(a::AbstractArray, ifirst::Int, ilast::Int)
end
end

sum(a::AbstractArray) = sum_pairwise(a, 1, length(a))
sum{T<:Integer}(a::AbstractArray{T}) = sum_seq(a, 1, length(a))
function sum{T}(a::AbstractArray{T})
n = length(a)
n == 0 && return zero(T) + zero(T)
n == 1 && return a[1] + zero(a[1])
sum_pairwise(a, 1, length(a))
end

function sum{T<:Integer}(a::AbstractArray{T})
n = length(a)
n == 0 && return zero(T) + zero(T)
n == 1 && return a[1] + zero(a[1])
sum_seq(a, 1, length(a))
end

# Kahan (compensated) summation: O(1) error growth, at the expense
# of a considerable increase in computational expense.
function sum_kbn{T<:FloatingPoint}(A::AbstractArray{T})
n = length(A)
if (n == 0)
return zero(T)
return zero(T)+zero(T)
end
s = A[1]
c = zero(T)
s = A[1]+zero(T)
c = zero(T)+zero(T)
for i in 2:n
Ai = A[i]
t = s + Ai
Expand All @@ -286,29 +315,40 @@ end
function prod(itr)
s = start(itr)
if done(itr, s)
return *()
if method_exists(eltype, (typeof(itr),))
T = eltype(itr)
return one(T) * one(T)
else
throw(ArgumentError("prod(itr) is undefined for empty collections; instead, do isempty(itr) ? o : prod(itr), where o is the correct type of identity for your product"))
end
end
(v, s) = next(itr, s)
done(itr, s) && return v * one(v) # multiplying by one for type stability
# specialize for length > 1 to have type-stable loop
(x, s) = next(itr, s)
result = v * x
while !done(itr, s)
(x, s) = next(itr, s)
v = v*x
result *= x
end
return v
return result
end

prod(A::AbstractArray{Bool}) =
error("use all() instead of prod() for boolean arrays")

function prod_rgn{T}(A::AbstractArray{T}, first::Int, last::Int)
if first > last
return one(T)
return one(T) * one(T)
end
i = first
v = A[i]
@inbounds v = A[i]
i == last && return v * one(v)
@inbounds result = v * A[i+=1]
while i < last
@inbounds v *= A[i+=1]
@inbounds result *= A[i+=1]
end
return v
return result
end
prod{T}(A::AbstractArray{T}) = prod_rgn(A, 1, length(A))

Expand Down Expand Up @@ -467,11 +507,17 @@ function mapreduce(f::Callable, op::Callable, itr)
end
(x, s) = next(itr, s)
v = f(x)
while !done(itr, s)
if done(itr, s)
return v
else # specialize for length > 1 to have a hopefully type-stable loop
(x, s) = next(itr, s)
v = op(v,f(x))
result = op(v, f(x))
while !done(itr, s)
(x, s) = next(itr, s)
result = op(result, f(x))
end
return result
end
return v
end

function mapreduce(f::Callable, op::Callable, v0, itr)
Expand All @@ -482,10 +528,11 @@ function mapreduce(f::Callable, op::Callable, v0, itr)
return v
end

# pairwise reduction, requires n > 1 (to allow type-stable loop)
function mr_pairwise(f::Callable, op::Callable, A::AbstractArray, i1,n)
if n < 128
@inbounds v = f(A[i1])
for i = i1+1:i1+n-1
@inbounds v = op(f(A[i1]), f(A[i1+1]))
for i = i1+2:i1+n-1
@inbounds v = op(v,f(A[i]))
end
return v
Expand All @@ -496,11 +543,11 @@ function mr_pairwise(f::Callable, op::Callable, A::AbstractArray, i1,n)
end
function mapreduce(f::Callable, op::Callable, A::AbstractArray)
n = length(A)
n == 0 ? op() : mr_pairwise(f,op,A, 1,n)
n == 0 ? op() : n == 1 ? f(A[1]) : mr_pairwise(f,op,A, 1,n)
end
function mapreduce(f::Callable, op::Callable, v0, A::AbstractArray)
n = length(A)
n == 0 ? v0 : op(v0, mr_pairwise(f,op,A, 1,n))
n == 0 ? v0 : n == 1 ? op(v0, f(A[1])) : op(v0, mr_pairwise(f,op,A, 1,n))
end

# specific mapreduce functions
Expand Down
29 changes: 29 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,35 @@ end

@test sum(z) == sum(z,(1,2,3,4))[1] == 136

# check variants of summation for type-stability and other issues (#6069)
sum2(itr) = invoke(sum, (Any,), itr)
plus(x,y) = x + y
plus() = 0
sum3(A) = reduce(plus, A)
sum4(itr) = invoke(reduce, (Function, Any), plus, itr)
sum5(A) = reduce(plus, 0, A)
sum6(itr) = invoke(reduce, (Function, Int, Any), plus, 0, itr)
sum7(A) = mapreduce(x->x, plus, A)
sum8(itr) = invoke(mapreduce, (Function, Function, Any), x->x, plus, itr)
sum9(A) = mapreduce(x->x, plus, 0, A)
sum10(itr) = invoke(mapreduce, (Function, Function, Int, Any), x->x,plus,0,itr)
for f in (sum2, sum3, sum4, sum5, sum6, sum7, sum8, sum9, sum10)
@test sum(z) == f(z)
@test sum(Int[]) == f(Int[]) == 0
@test sum(Int[7]) == f(Int[7]) == 7
if f == sum3 || f == sum4 || f == sum7 || f == sum8
@test typeof(f(Int8[])) == typeof(f(Int8[1 7]))
else
@test typeof(f(Int8[])) == typeof(f(Int8[1])) == typeof(f(Int8[1 7]))
end
end
@test typeof(sum(Int8[])) == typeof(sum(Int8[1])) == typeof(sum(Int8[1 7]))

prod2(itr) = invoke(prod, (Any,), itr)
@test prod(Int[]) == prod2(Int[]) == 1
@test prod(Int[7]) == prod2(Int[7]) == 7
@test typeof(prod(Int8[])) == typeof(prod(Int8[1])) == typeof(prod(Int8[1 7])) == typeof(prod2(Int8[])) == typeof(prod2(Int8[1])) == typeof(prod2(Int8[1 7]))

v = cell(2,2,1,1)
v[1,1,1,1] = 28.0
v[1,2,1,1] = 36.0
Expand Down

0 comments on commit d114e11

Please sign in to comment.