Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Wfact and Wfact_t support #149

Merged
merged 2 commits into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions src/functionwrapper.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
# convenience macro
macro wrap_h(signature)
Meta.isexpr(signature, :call) ||
throw(ArgumentError("signature has to be a function call expression"))

name = signature.args[1]
args = signature.args[2:end]
args_wo_h = [arg for arg in args if arg !== :h]

quote
if f.$name === nothing
nothing
else
if isinplace(f)
let _f = f.$name, h = h
($(args_wo_h...),) -> _f($(args...))
end
else
let _f = f.$name, h = h
($(args_wo_h[2:end]...),) -> _f($(args[2:end]...))
end
end
end
end |> esc
end

struct ODEFunctionWrapper{iip,F,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBase.AbstractODEFunction{iip}
f::F
h::H
Expand All @@ -14,26 +40,17 @@ struct ODEFunctionWrapper{iip,F,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBas
end

function ODEFunctionWrapper(f::DDEFunction, h)
if f.jac === nothing
jac = nothing
else
if isinplace(f)
jac = let f_jac = f.jac, h = h
(J, u, p, t) -> f_jac(J, u, h, p, t)
end
else
jac = let f_jac = f.jac, h = h
(u, p, t) -> f_jac(u, h, p, t)
end
end
end
# wrap functions
jac = @wrap_h jac(J, u, h, p, t)
Wfact = @wrap_h Wfact(W, u, h, p, dtgamma, t)
Wfact_t = @wrap_h Wfact_t(W, u, h, p, dtgamma, t)

ODEFunctionWrapper{isinplace(f),typeof(f.f),typeof(h),typeof(f.mass_matrix),
typeof(f.analytic),typeof(f.tgrad),typeof(jac),
typeof(f.jac_prototype),typeof(f.Wfact),typeof(f.Wfact_t),
typeof(f.jac_prototype),typeof(Wfact),typeof(Wfact_t),
typeof(f.paramjac),typeof(f.syms),typeof(f.colorvec)}(
f.f, h, f.mass_matrix, f.analytic, f.tgrad, jac,
f.jac_prototype, f.Wfact, f.Wfact_t, f.paramjac, f.syms,
f.jac_prototype, Wfact, Wfact_t, f.paramjac, f.syms,
f.colorvec)
end

Expand Down
169 changes: 150 additions & 19 deletions test/interface/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,180 @@ using Test
nothing
end

function g(J, u, h, p, t)
njacs = Ref(0)
function jac(J, u, h, p, t)
njacs[] += 1
J[1, 1] = 1 - h(p, t - 1)[1]
nothing
end

nWfacts = Ref(0)
function Wfact(W, u, h, p, dtgamma, t)
nWfacts[] += 1
W[1,1] = dtgamma * (1 - h(p, t - 1)[1]) - 1
nothing
end

nWfact_ts = Ref(0)
function Wfact_t(W, u, h, p, dtgamma, t)
nWfact_ts[] += 1
W[1,1] = 1 - h(p, t - 1)[1] - inv(dtgamma)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
nothing
end

h(p, t) = [0.0]

# define problems
prob_wo_jac = DDEProblem(DDEFunction{true}(f), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob_w_jac = DDEProblem(DDEFunction{true}(f; jac = g), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob = DDEProblem(DDEFunction{true}(f), [1.0], h, (0.0, 40.0); constant_lags = [1])
prob_jac = remake(prob; f = DDEFunction{true}(f; jac = jac))
prob_Wfact = remake(prob; f = DDEFunction{true}(f; Wfact = Wfact))
prob_Wfact_t = remake(prob; f = DDEFunction{true}(f; Wfact_t = Wfact_t))

# compute solutions
for alg in (Rosenbrock23(), TRBDF2())
sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg))
sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg))
sol = solve(prob, MethodOfSteps(alg))

## Jacobian
njacs[] = 0
sol_jac = solve(prob_jac, MethodOfSteps(alg))

# check number of function evaluations
@test !iszero(njacs[])
@test njacs[] == sol_jac.destats.njacs
if alg isa Rosenbrock23
@test njacs[] == sol_jac.destats.nw
else
@test_broken njacs[] == sol_jac.destats.nw
end

# check resulting solution
@test sol.t ≈ sol_jac.t
@test sol.u ≈ sol_jac.u

## Wfact
nWfacts[] = 0
sol_Wfact = solve(prob_Wfact, MethodOfSteps(alg))

# check number of function evaluations
if alg isa Rosenbrock23
@test !iszero(nWfacts[])
@test nWfacts[] == njacs[]
@test iszero(sol_Wfact.destats.njacs)
else
@test_broken !iszero(nWfacts[])
@test_broken nWfacts[] == njacs[]
@test_broken iszero(sol_Wfact.destats.njacs)
end
@test_broken nWfacts[] == sol_Wfact.destats.nw

# check resulting solution
@test sol.t ≈ sol_Wfact.t
@test sol.u ≈ sol_Wfact.u

## Wfact_t
nWfact_ts[] = 0
sol_Wfact_t = solve(prob_Wfact_t, MethodOfSteps(alg))

# check number of function evaluations
if alg isa Rosenbrock23
@test_broken !iszero(nWfact_ts[])
@test_broken nWfact_ts[] == njacs[]
@test_broken iszero(sol_Wfact_t.destats.njacs)
else
@test !iszero(nWfact_ts[])
@test_broken nWfact_ts[] == njacs[]
@test iszero(sol_Wfact_t.destats.njacs)
end
@test_broken nWfact_ts[] == sol_Wfact_t.destats.nw

@test sol_wo_jac.t ≈ sol_w_jac.t
@test sol_wo_jac.u ≈ sol_w_jac.u
# check resulting solution
if alg isa Rosenbrock23
@test sol.t ≈ sol_Wfact_t.t
@test sol.u ≈ sol_Wfact_t.u
else
@test_broken sol.t ≈ sol_Wfact_t.t
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that means it's not working, since IIRC at this point only Rosenbrock23 is using the untransformed version still?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, as I said, maybe it's a bit too ambitious at this point. The main motivation of the tests is that:

  1. If we specify the Jacobian, the number of Jacobian and W evaluations should be equal to the non-zero value of our counter, and the solution should be the same.
  2. If we specify Wfact or Wfact_t, the number of Jacobian evaluations should be zero and the number of W evaluations should be equal to the non-zero value of our counter, and the solution should be the same.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see: the time points are equal, not that the value is correct. Yeah, I am not sure this will be the case, since symbolic CSE stuff happens on the symbolic factorization (the main source of the speedup)

@test_broken sol.u ≈ sol_Wfact_t.u
end
end
end

@testset "out-of-place" begin
# define functions (Hutchinson's equation)
f(u, h, p, t) = [u[1] * (1 - h(p, t - 1)[1])]
f(u, h, p, t) = u[1] .* (1 .- h(p, t - 1))

g(u, h, p, t) = fill(1 - h(p, t - 1)[1], 1, 1)
njacs = Ref(0)
function jac(u, h, p, t)
njacs[] += 1
reshape(1 .- h(p, t - 1), 1, 1)
end

nWfacts = Ref(0)
function Wfact(u, h, p, dtgamma, t)
nWfacts[] += 1
reshape(dtgamma .* (1 .- h(p, t - 1)) .- 1, 1, 1)
end

nWfact_ts = Ref(0)
function Wfact_t(u, h, p, dtgamma, t)
nWfact_ts[] += 1
reshape((1 - inv(dtgamma)) .- h(p, t - 1), 1, 1)
end

h(p, t) = [0.0]

# define problems
prob_wo_jac = DDEProblem(DDEFunction{false}(f), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob_w_jac = DDEProblem(DDEFunction{false}(f; jac = g), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob = DDEProblem(DDEFunction{false}(f), [1.0], h, (0.0, 40.0); constant_lags = [1])
prob_jac = remake(prob; f = DDEFunction{false}(f; jac = jac))
prob_Wfact = remake(prob; f = DDEFunction{false}(f; Wfact = Wfact))
prob_Wfact_t = remake(prob; f = DDEFunction{false}(f; Wfact_t = Wfact_t))

# compute solutions
for alg in (Rosenbrock23(), TRBDF2())
sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg))
sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg))
sol = solve(prob, MethodOfSteps(alg))

## Jacobian
njacs[] = 0
sol_jac = solve(prob_jac, MethodOfSteps(alg))

# check number of function evaluations
@test !iszero(njacs[])
@test_broken njacs[] == sol_jac.destats.njacs
if alg isa Rosenbrock23
@test njacs[] == sol_jac.destats.nw
else
@test_broken njacs[] == sol_jac.destats.nw
end

# check resulting solution
@test sol.t ≈ sol_jac.t
@test sol.u ≈ sol_jac.u

## Wfact
nWfacts[] = 0
sol_Wfact = solve(prob_Wfact, MethodOfSteps(alg))

# check number of function evaluations
@test_broken !iszero(nWfacts[])
@test_broken nWfacts[] == njacs[]
@test_broken iszero(sol_Wfact.destats.njacs)
@test_broken nWfacts[] == sol_Wfact.destats.nw

# check resulting solution
@test sol.t ≈ sol_Wfact.t
@test sol.u ≈ sol_Wfact.u

## Wfact_t
nWfact_ts[] = 0
sol_Wfact_t = solve(prob_Wfact_t, MethodOfSteps(alg))

# check number of function evaluations
@test_broken !iszero(nWfact_ts[])
@test_broken nWfact_ts[] == njacs[]
@test_broken iszero(sol_Wfact_ts.destats.njacs)
@test_broken nWfact_ts[] == sol_Wfact_t.destats.nw

@test sol_wo_jac.t ≈ sol_w_jac.t
@test sol_wo_jac.u ≈ sol_w_jac.u
# check resulting solution
@test sol.t ≈ sol_Wfact_t.t
@test sol.u ≈ sol_Wfact_t.u
end
end