Skip to content

Commit

Permalink
Use type inference to decide the return type of the broadcast syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyichao committed Sep 5, 2016
1 parent 9cfc752 commit 7f939ad
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
50 changes: 48 additions & 2 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,51 @@ end

istopsymbol(ex, mod, sym) = ex in (sym, Expr(:(.), mod, Expr(:quote, sym)))

if VERSION < v"0.5.0-dev+4002"
@inline getindex_vararg(idx) = ()
@inline getindex_vararg(idx, arg1) = (arg1[idx],)
@inline getindex_vararg(idx, arg1, arg2) = (arg1[idx], arg2[idx])
@inline getindex_vararg(idx, arg1, arg2, arg3) =
(arg1[idx], arg2[idx], arg3[idx])
@inline getindex_vararg(idx, arg1, arg2, arg3, arg4) =
(arg1[idx], arg2[idx], arg3[idx], arg4[idx])
@inline getindex_vararg(idx, arg1, arg2, arg3, arg4, args...) =
(arg1[idx], arg2[idx], arg3[idx], arg4[idx],
getindex_vararg(idx, args...)...)
function rewrite_broadcast(f, args)
nargs = length(args)
is_vararg = Bool[isexpr(args[i], :...) for i in 1:nargs]
names = [gensym("broadcast") for i in 1:nargs]
new_args = [is_vararg[i] ? Expr(:..., names[i]) : names[i]
for i in 1:nargs]
getidx_for = function (i, idxvar)
if is_vararg[i]
Expr(:..., :($Compat.getindex_vararg($idxvar, $(names[i])...)))
else
:($(names[i])[$idxvar])
end
end
@gensym allidx
@gensym newshape
@gensym res1d
@gensym idx
return quote
# The `local` makes sure type inference can infer the type even
# in global scope as long as the input is type stable
local $(names...)
$([:($(names[i]) = $(is_vararg[i] ? args[i].args[1] : args[i]))
for i in 1:nargs]...)
local $newshape = $Base.Broadcast.broadcast_shape($(new_args...))
local $allidx = $(VERSION >= v"0.4.0-dev+1624" ?
:($Compat.eachindex($(new_args...))) :
:(1:prod($newshape)))
local $res1d = [$f($([getidx_for(i, idx) for i in 1:nargs]...))
for $idx in $allidx]
$Base.reshape($res1d, $newshape)
end
end
end

function _compat(ex::Expr)
if ex.head === :call
f = ex.args[1]
Expand Down Expand Up @@ -549,11 +594,12 @@ function _compat(ex::Expr)
return Expr(ex.head, _compat(ex.args[1]), QuoteNode(ex.args[2].args[1].args[1]))
elseif isexpr(ex.args[2], :tuple)
# f.(arg1, arg2...) -> broadcast(f, arg1, arg2...)
return Expr(:call, :broadcast, _compat(ex.args[1]), map(_compat, ex.args[2].args)...)
return rewrite_broadcast(_compat(ex.args[1]),
map(_compat, ex.args[2].args))
elseif !isa(ex.args[2], QuoteNode) &&
!(isexpr(ex.args[2], :quote) && isa(ex.args[2].args[1], Symbol))
# f.(arg) -> broadcast(f, arg)
return Expr(:call, :broadcast, _compat(ex.args[1]), _compat(ex.args[2]))
return rewrite_broadcast(_compat(ex.args[1]), [_compat(ex.args[2])])
end
elseif ex.head === :import
if VERSION < v"0.5.0-dev+4340" && length(ex.args) == 2 && ex.args[1] === :Base && ex.args[2] === :show
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,11 @@ let x = rand(3), y = rand(3)
@test @compat(sin.(cos.(x))) == map(x -> sin(cos(x)), x)
@test @compat(atan2.(sin.(y),x)) == broadcast(atan2,map(sin,y),x)
end
# Do the following in global scope to make sure inference is able to handle it
@test @compat(sin.([1, 2])) == [sin(1), sin(2)]
@test isa(@compat(sin.([1, 2])), Vector{Float64})
@test @compat(atan2.([1, 2], [2, 3])) == [atan2(1, 2), atan2(2, 3)]
@test isa(@compat(atan2.([1, 2], [2, 3])), Vector{Float64})

if VERSION v"0.4.0-dev+3732"
@test Symbol("foo") === :foo
Expand Down

0 comments on commit 7f939ad

Please sign in to comment.