Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive and dotted replacement with muladd #57

Merged
merged 4 commits into from
Jul 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 120 additions & 23 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,41 +26,138 @@ macro def(name, definition)
end
end

"""
@muladd ex

Convert every combined multiplication and addition in `ex` into a call of `muladd`. If both
of the involved operators are dotted, `muladd` is applied as a "dot call".
"""
macro muladd(ex)
esc(to_muladd(ex))
end

function to_muladd(ex)
is_add_operation(ex) || return ex
function to_muladd(ex::Expr)
if !isaddition(ex)
if ex.head == :macrocall && length(ex.args)==2 && ex.args[1] == Symbol("@__dot__")
# expand @. macros first (enables use of @. inside of @muladd expression)
return to_muladd(Base.Broadcast.__dot__(ex.args[2]))
else
# if expression is no sum apply the reduction to its arguments
return Expr(ex.head, to_muladd.(ex.args)...)
end
end

all_operands = ex.args[2:end]
mul_operands = filter(is_mul_operation, all_operands)
odd_operands = filter(x->!is_mul_operation(x), all_operands)
# retrieve summands of addition and split them into two groups, one with expressions
# of multiplications and one with other expressions
# if addition is a dot call multiplications must be dot calls as well; if the addition
# is a regular operation only regular multiplications are filtered
all_operands = to_muladd.(operands(ex))
if isdotcall(ex)
mul_operands = filter(x->isdotcall(x, :*), all_operands)
odd_operands = filter(x->!isdotcall(x, :*), all_operands)
else
mul_operands = filter(x->isoperation(x, :*), all_operands)
odd_operands = filter(x->!isoperation(x, :*), all_operands)
end

muladd_operands = collect(zip(
to_muladd.((x->x.args[2]).(mul_operands)),
to_muladd.((x->x.args[3]).(mul_operands))))
# define summands that are reduced with muladd and the initial element of the reduction
if isempty(odd_operands)
# if all summands are multiplications one of these summands is
# the initial element of the reduction
to_be_muladded = mul_operands[1:end-1]
last_operation = mul_operands[end]
else
to_be_muladded = mul_operands

# expressions that are no multiplications are summed up in a separate expression
# that is the initial element of the reduction
# if the original addition was a dot call this expression also is a dot call
if length(odd_operands) == 1
last_operation = odd_operands[1]
elseif isdotcall(ex)
# make sure returned expression has same style as original expression
if ex.head == :.
last_operation = Expr(:., :+, Expr(:tuple, odd_operands...))
else
last_operation = Expr(:call, :.+, odd_operands...)
end
else
last_operation = Expr(:call, :+, odd_operands...)
end
end

if isempty(odd_operands)
to_be_muladded = muladd_operands[1:end-1]
last_operation = :($(muladd_operands[end][1]) * $(muladd_operands[end][2]))
else
to_be_muladded = muladd_operands
last_operation = make_addition(odd_operands)
end
# reduce sum to a composition of muladd
foldr(last_operation, to_be_muladded) do xs, r
# retrieve factors of multiplication that will be reduced next
xs_operands = operands(xs)

# first factor is always first operand
xs_factor1 = xs_operands[1]

# second factor is an expression of a multiplication if there are more than
# two operands
# if the original multiplication was a dot call this expression also is a dot call
if length(xs_operands) == 2
xs_factor2 = xs_operands[2]
elseif isdotcall(xs)
xs_factor2 = Expr(:., :*, Expr(:tuple, xs_operands[2:end]...))
else
xs_factor2 = Expr(:call, :*, xs_operands[2:end]...)
end

foldr(last_operation, to_be_muladded) do xs, r
:($(Base.muladd)($(xs[1]), $(xs[2]), $r))
end
# create a dot call if both involved operators are dot calls
if isdotcall(ex)
Expr(:., Base.muladd, Expr(:tuple, xs_factor1, xs_factor2, r))
else
Expr(:call, Base.muladd, xs_factor1, xs_factor2, r)
end
end
end
to_muladd(ex) = ex

"""
isoperation(ex, op::Symbol)

is_operation(ex::Expr, op::Symbol) = ex.head == :call && !isempty(ex.args) && ex.args[1] == op
is_operation(ex, op::Symbol) = false
Determine whether `ex` is a call of operation `op`.
"""
isoperation(ex::Expr, op::Symbol) =
ex.head == :call && !isempty(ex.args) && ex.args[1] == op
isoperation(ex, op::Symbol) = false

"""
isdotcall(ex[, op])

Determine whether `ex` is a dot call and, in case `op` is specified, whether it calls
operator `op`.
"""
isdotcall(ex::Expr) = !isempty(ex.args) &&
(ex.head == :. ||
(ex.head == :call && !isempty(ex.args) && first(string(ex.args[1])) == '.'))
isdotcall(ex) = false

isdotcall(ex::Expr, op::Symbol) = isdotcall(ex) &&
(ex.args[1] == op || ex.args[1] == Symbol('.', op))
isdotcall(ex, op::Symbol) = false

"""
isaddition(ex)

Determine whether `ex` is an expression of an addition.
"""
isaddition(ex) = isoperation(ex, :+) || isdotcall(ex, :+)

is_add_operation(ex) = is_operation(ex, :+)
is_mul_operation(ex) = is_operation(ex, :*)
"""
operands(ex)

make_addition(args) = length(args) == 1 ? args[1] : Expr(:call, :+, args...)
Return arguments of function call in `ex`.
"""
function operands(ex::Expr)
if ex.head == :. && length(ex.args) == 2 && typeof(ex.args[2]) <: Expr
ex.args[2].args
else
ex.args[2:end]
end
end

realtype{T}(::Type{T}) = T
realtype{T}(::Type{Complex{T}}) = T
Expand Down
34 changes: 34 additions & 0 deletions test/muladd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using DiffEqBase, Base.Test

# Basic expressions
@test macroexpand(:(@muladd a*b+c)) == :($(Base.muladd)(a, b, c))
@test macroexpand(:(@muladd c+a*b)) == :($(Base.muladd)(a, b, c))
@test macroexpand(:(@muladd b*a+c)) == :($(Base.muladd)(b, a, c))
@test macroexpand(:(@muladd c+b*a)) == :($(Base.muladd)(b, a, c))

# Multiple multiplications
@test macroexpand(:(@muladd a*b+c*d)) == :($(Base.muladd)(a, b, c*d))
@test macroexpand(:(@muladd a*b+c*d+e*f)) == :($(Base.muladd)(a, b,
$(Base.muladd)(c, d, e*f)))
@test macroexpand(:(@muladd a*(b*c+d)+e)) == :($(Base.muladd)(a,
$(Base.muladd)(b, c, d), e))

# Dot calls
@test macroexpand(:(@. @muladd a*b+c)) == :($(Base.muladd).(a, b, c))
@test macroexpand(:(@muladd @. a*b+c)) == :($(Base.muladd).(a, b, c))
@test macroexpand(:(@muladd a.*b+c)) == :(a.*b+c)
@test macroexpand(:(@muladd a*b.+c)) == :(a*b.+c)
@test macroexpand(:(@muladd f.(a)*b+c)) == :($(Base.muladd)(f.(a), b, c))
@test macroexpand(:(@muladd a*f.(b)+c)) == :($(Base.muladd)(a, f.(b), c))
@test macroexpand(:(@muladd a*b+f.(c))) == :($(Base.muladd)(a, b, f.(c)))

# Nested expressions
@test macroexpand(:(@muladd f(x, y, z) = x*y+z)) == :(f(x, y, z) = $(Base.muladd)(x, y, z))
@test macroexpand(:(@muladd function f(x, y, z) x*y+z end)) ==
:(function f(x, y, z) $(Base.muladd)(x, y, z) end)
@test macroexpand(:(@muladd for i in 1:n z = x*i + y end)) ==
:(for i in 1:n z = $(Base.muladd)(x, i, y) end)

# Additional factors
@test macroexpand(:(@muladd a*b*c+d)) == :($(Base.muladd)(a, b*c, d))
@test macroexpand(:(@muladd a*b*c*d+e)) == :($(Base.muladd)(a, b*c*d, e))
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ using Base.Test
@time @testset "Callbacks" begin include("callbacks.jl") end
@time @testset "Constructed Parameterized Functions" begin include("constructed_pf_test.jl") end
@time @testset "Plot Variables" begin include("plot_vars.jl") end
@time @testset "Muladd Macro" begin include("muladd.jl") end