Skip to content

Commit

Permalink
Use type inference to decide the return type of the broadcast syntax
Browse files Browse the repository at this point in the history
Add much more tests to ensure the behavior of the broadcast syntax is
as consistent as possible on different julia versions.
  • Loading branch information
yuyichao committed Sep 8, 2016
1 parent ac912ce commit fbf0173
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 7 deletions.
152 changes: 145 additions & 7 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ using Base.Meta
"""Get just the function part of a function declaration."""
withincurly(ex) = isexpr(ex, :curly) ? ex.args[1] : ex

if VERSION < v"0.4.0-dev+2254"
immutable Val{T} end
export Val
end

if VERSION < v"0.4.0-dev+1419"
export UInt, UInt8, UInt16, UInt32, UInt64, UInt128
const UInt = Uint
Expand Down Expand Up @@ -431,6 +436,143 @@ end

istopsymbol(ex, mod, sym) = ex in (sym, Expr(:(.), mod, Expr(:quote, sym)))

if VERSION < v"0.5.0-dev+4002"
typealias Array0D{T} Array{T,0}
@inline broadcast_getindex(arg, idx) = arg[(idx - 1) % length(arg) + 1]
# Optimize for single element
@inline broadcast_getindex(arg::Number, idx) = arg
@inline broadcast_getindex(arg::Array0D, idx) = arg[1]

# If we know from syntax level that we don't need wrapping
@inline broadcast_getindex_naive(arg, idx) = arg[idx]
@inline broadcast_getindex_naive(arg::Number, idx) = arg
@inline broadcast_getindex_naive(arg::Array0D, idx) = arg[1]

# For vararg support
@inline getindex_vararg(idx) = ()
@inline getindex_vararg(idx, arg1) = (broadcast_getindex(arg1, idx),)
@inline getindex_vararg(idx, arg1, arg2) =
(broadcast_getindex(arg1, idx), broadcast_getindex(arg2, idx))
@inline getindex_vararg(idx, arg1, arg2, arg3, args...) =
(broadcast_getindex(arg1, idx), broadcast_getindex(arg2, idx),
broadcast_getindex(arg3, idx), getindex_vararg(idx, args...)...)

@inline getindex_naive_vararg(idx) = ()
@inline getindex_naive_vararg(idx, arg1) =
(broadcast_getindex_naive(arg1, idx),)
@inline getindex_naive_vararg(idx, arg1, arg2) =
(broadcast_getindex_naive(arg1, idx),
broadcast_getindex_naive(arg2, idx))
@inline getindex_naive_vararg(idx, arg1, arg2, arg3, args...) =
(broadcast_getindex_naive(arg1, idx),
broadcast_getindex_naive(arg2, idx),
broadcast_getindex_naive(arg3, idx),
getindex_naive_vararg(idx, args...)...)

# Decide if the result should be scalar or array
# `size() === ()` is not good enough since broadcasting on
# a scalar should return a scalar where as broadcasting on a 0-dim
# array should return a 0-dim array.
@inline should_return_array(::Val{true}, args...) = Val{true}()
@inline should_return_array(::Val{false}) = Val{false}()
@inline should_return_array(::Val{false}, arg1) = Val{false}()
@inline should_return_array(::Val{false}, arg1::AbstractArray) = Val{true}()
@inline should_return_array(::Val{false}, arg1::AbstractArray,
arg2::AbstractArray) = Val{true}()
@inline should_return_array(::Val{false}, arg1,
arg2::AbstractArray) = Val{true}()
@inline should_return_array(::Val{false}, arg1::AbstractArray,
arg2) = Val{true}()
@inline should_return_array(::Val{false}, arg1, arg2) = Val{false}()
@inline should_return_array(::Val{false}, arg1, arg2, args...) =
should_return_array(should_return_array(Val{false}(), arg1, arg2),
args...)

@inline broadcast_return(res1d, shp, ret_ary::Val{false}) = res1d[1]
@inline broadcast_return(res1d, shp, ret_ary::Val{true}) = reshape(res1d, shp)

@inline need_full_getindex(shp) = false
@inline need_full_getindex(shp, arg1::Number) = false
@inline need_full_getindex(shp, arg1::Array0D) = false
@inline need_full_getindex(shp, arg1) = shp != size(arg1)
@inline need_full_getindex(shp, arg1, arg2) =
need_full_getindex(shp, arg1) || need_full_getindex(shp, arg2)
@inline need_full_getindex(shp, arg1, arg2, arg3, args...) =
need_full_getindex(shp, arg1, arg2) || need_full_getindex(shp, arg3) ||
need_full_getindex(shp, args...)

function rewrite_broadcast(f, args)
nargs = length(args)
# This actually allows multiple splatting...,
# which is now allowed on master.
# The previous version that simply calls broadcast so removing that
# will be breaking. Oh, well....
is_vararg = Bool[isexpr(args[i], :...) for i in 1:nargs]
names = [gensym("broadcast") for i in 1:nargs]
new_args = [is_vararg[i] ? Expr(:..., names[i]) : names[i]
for i in 1:nargs]
# Optimize for common case where we know the index doesn't need
# any wrapping
naive_getidx_for = function (i, idxvar)
if is_vararg[i]
Expr(:..., :($Compat.getindex_naive_vararg($idxvar,
$(names[i])...)))
else
:($Compat.broadcast_getindex_naive($(names[i]), $idxvar))
end
end
always_naive = nargs == 1 && !is_vararg[1]
getidx_for = if always_naive
naive_getidx_for
else
function (i, idxvar)
if is_vararg[i]
Expr(:..., :($Compat.getindex_vararg($idxvar,
$(names[i])...)))
else
:($Compat.broadcast_getindex($(names[i]), $idxvar))
end
end
end
@gensym allidx
@gensym newshape
@gensym res1d
@gensym idx
@gensym ret_ary

res1d_expr = quote
$res1d = [$f($([naive_getidx_for(i, idx) for i in 1:nargs]...))
for $idx in $allidx]
end
if !always_naive
res1d_expr = quote
if $Compat.need_full_getindex($newshape, $(new_args...))
$res1d = [$f($([getidx_for(i, idx) for i in 1:nargs]...))
for $idx in $allidx]
else
$res1d_expr
end
end
end

return quote
# The `local` makes sure type inference can infer the type even
# in global scope as long as the input is type stable
local $(names...)
$([:($(names[i]) = $(is_vararg[i] ? args[i].args[1] : args[i]))
for i in 1:nargs]...)
local $newshape = $(Base.Broadcast).broadcast_shape($(new_args...))
# `eachindex` is not generic enough
local $allidx = 1:prod($newshape)
local $ret_ary = $Compat.should_return_array(Val{false}(),
$(new_args...))
local $res1d
$res1d_expr
$Compat.broadcast_return($res1d, $newshape, $ret_ary)
end
end
end

function _compat(ex::Expr)
if ex.head === :call
f = ex.args[1]
Expand Down Expand Up @@ -549,11 +691,12 @@ function _compat(ex::Expr)
return Expr(ex.head, _compat(ex.args[1]), QuoteNode(ex.args[2].args[1].args[1]))
elseif isexpr(ex.args[2], :tuple)
# f.(arg1, arg2...) -> broadcast(f, arg1, arg2...)
return Expr(:call, :broadcast, _compat(ex.args[1]), map(_compat, ex.args[2].args)...)
return rewrite_broadcast(_compat(ex.args[1]),
map(_compat, ex.args[2].args))
elseif !isa(ex.args[2], QuoteNode) &&
!(isexpr(ex.args[2], :quote) && isa(ex.args[2].args[1], Symbol))
# f.(arg) -> broadcast(f, arg)
return Expr(:call, :broadcast, _compat(ex.args[1]), _compat(ex.args[2]))
return rewrite_broadcast(_compat(ex.args[1]), [_compat(ex.args[2])])
end
elseif ex.head === :import
if VERSION < v"0.5.0-dev+4340" && length(ex.args) == 2 && ex.args[1] === :Base && ex.args[2] === :show
Expand Down Expand Up @@ -668,11 +811,6 @@ if VERSION < v"0.4.0-dev+4502"
export keytype, valtype
end

if VERSION < v"0.4.0-dev+2254"
immutable Val{T} end
export Val
end

if VERSION < v"0.4.0-dev+2840"
Base.qr(A, ::Type{Val{true}}; thin::Bool=true) =
Base.qr(A, pivot=true, thin=thin)
Expand Down
30 changes: 30 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,36 @@ let x = rand(3), y = rand(3)
@test @compat(sin.(cos.(x))) == map(x -> sin(cos(x)), x)
@test @compat(atan2.(sin.(y),x)) == broadcast(atan2,map(sin,y),x)
end
let x0 = Array(Float64), v, v0
x0[1] = rand()
v0 = @compat sin.(x0)
@test isa(v0, Array{Float64,0})
v = @compat sin.(x0[1])
@test isa(v, Float64)
@test v == v0[1] == sin(x0[1])
end
let x = rand(2, 2), v
v = @compat sin.(x)
@test isa(v, Array{Float64,2})
@test v == [sin(x[1, 1]) sin(x[1, 2]);
sin(x[2, 1]) sin(x[2, 2])]
end
let x1 = [1, 2, 3], x2 = ([3, 4, 5],), v
v = @compat atan2.(x1, x2...)
@test isa(v, Vector{Float64})
@test v == [atan2(1, 3), atan2(2, 4), atan2(3, 5)]
end
# Do the following in global scope to make sure inference is able to handle it
@test @compat(sin.([1, 2])) == [sin(1), sin(2)]
@test isa(@compat(sin.([1, 2])), Vector{Float64})
@test @compat(atan2.(1, [2, 3])) == [atan2(1, 2), atan2(1, 3)]
@test isa(@compat(atan2.(1, [2, 3])), Vector{Float64})
@test @compat(atan2.([1, 2], [2, 3])) == [atan2(1, 2), atan2(2, 3)]
@test isa(@compat(atan2.([1, 2], [2, 3])), Vector{Float64})
# And make sure it is actually inferrable
f15032(a) = @compat sin.(a)
@inferred f15032([1, 2, 3])
@inferred f15032([1.0, 2.0, 3.0])

if VERSION v"0.4.0-dev+3732"
@test Symbol("foo") === :foo
Expand Down

0 comments on commit fbf0173

Please sign in to comment.