Skip to content

Commit

Permalink
Merge pull request #149 from JuliaDiffEq/Wfact
Browse files Browse the repository at this point in the history
Add Wfact and Wfact_t support
  • Loading branch information
ChrisRackauckas authored Aug 28, 2019
2 parents c6695ec + 2ae59d3 commit 1576486
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 34 deletions.
47 changes: 32 additions & 15 deletions src/functionwrapper.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
# convenience macro
macro wrap_h(signature)
Meta.isexpr(signature, :call) ||
throw(ArgumentError("signature has to be a function call expression"))

name = signature.args[1]
args = signature.args[2:end]
args_wo_h = [arg for arg in args if arg !== :h]

quote
if f.$name === nothing
nothing
else
if isinplace(f)
let _f = f.$name, h = h
($(args_wo_h...),) -> _f($(args...))
end
else
let _f = f.$name, h = h
($(args_wo_h[2:end]...),) -> _f($(args[2:end]...))
end
end
end
end |> esc
end

struct ODEFunctionWrapper{iip,F,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBase.AbstractODEFunction{iip}
f::F
h::H
Expand All @@ -14,26 +40,17 @@ struct ODEFunctionWrapper{iip,F,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBas
end

function ODEFunctionWrapper(f::DDEFunction, h)
if f.jac === nothing
jac = nothing
else
if isinplace(f)
jac = let f_jac = f.jac, h = h
(J, u, p, t) -> f_jac(J, u, h, p, t)
end
else
jac = let f_jac = f.jac, h = h
(u, p, t) -> f_jac(u, h, p, t)
end
end
end
# wrap functions
jac = @wrap_h jac(J, u, h, p, t)
Wfact = @wrap_h Wfact(W, u, h, p, dtgamma, t)
Wfact_t = @wrap_h Wfact_t(W, u, h, p, dtgamma, t)

ODEFunctionWrapper{isinplace(f),typeof(f.f),typeof(h),typeof(f.mass_matrix),
typeof(f.analytic),typeof(f.tgrad),typeof(jac),
typeof(f.jac_prototype),typeof(f.Wfact),typeof(f.Wfact_t),
typeof(f.jac_prototype),typeof(Wfact),typeof(Wfact_t),
typeof(f.paramjac),typeof(f.syms),typeof(f.colorvec)}(
f.f, h, f.mass_matrix, f.analytic, f.tgrad, jac,
f.jac_prototype, f.Wfact, f.Wfact_t, f.paramjac, f.syms,
f.jac_prototype, Wfact, Wfact_t, f.paramjac, f.syms,
f.colorvec)
end

Expand Down
169 changes: 150 additions & 19 deletions test/interface/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,180 @@ using Test
nothing
end

function g(J, u, h, p, t)
njacs = Ref(0)
function jac(J, u, h, p, t)
njacs[] += 1
J[1, 1] = 1 - h(p, t - 1)[1]
nothing
end

nWfacts = Ref(0)
function Wfact(W, u, h, p, dtgamma, t)
nWfacts[] += 1
W[1,1] = dtgamma * (1 - h(p, t - 1)[1]) - 1
nothing
end

nWfact_ts = Ref(0)
function Wfact_t(W, u, h, p, dtgamma, t)
nWfact_ts[] += 1
W[1,1] = 1 - h(p, t - 1)[1] - inv(dtgamma)
nothing
end

h(p, t) = [0.0]

# define problems
prob_wo_jac = DDEProblem(DDEFunction{true}(f), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob_w_jac = DDEProblem(DDEFunction{true}(f; jac = g), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob = DDEProblem(DDEFunction{true}(f), [1.0], h, (0.0, 40.0); constant_lags = [1])
prob_jac = remake(prob; f = DDEFunction{true}(f; jac = jac))
prob_Wfact = remake(prob; f = DDEFunction{true}(f; Wfact = Wfact))
prob_Wfact_t = remake(prob; f = DDEFunction{true}(f; Wfact_t = Wfact_t))

# compute solutions
for alg in (Rosenbrock23(), TRBDF2())
sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg))
sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg))
sol = solve(prob, MethodOfSteps(alg))

## Jacobian
njacs[] = 0
sol_jac = solve(prob_jac, MethodOfSteps(alg))

# check number of function evaluations
@test !iszero(njacs[])
@test njacs[] == sol_jac.destats.njacs
if alg isa Rosenbrock23
@test njacs[] == sol_jac.destats.nw
else
@test_broken njacs[] == sol_jac.destats.nw
end

# check resulting solution
@test sol.t sol_jac.t
@test sol.u sol_jac.u

## Wfact
nWfacts[] = 0
sol_Wfact = solve(prob_Wfact, MethodOfSteps(alg))

# check number of function evaluations
if alg isa Rosenbrock23
@test !iszero(nWfacts[])
@test nWfacts[] == njacs[]
@test iszero(sol_Wfact.destats.njacs)
else
@test_broken !iszero(nWfacts[])
@test_broken nWfacts[] == njacs[]
@test_broken iszero(sol_Wfact.destats.njacs)
end
@test_broken nWfacts[] == sol_Wfact.destats.nw

# check resulting solution
@test sol.t sol_Wfact.t
@test sol.u sol_Wfact.u

## Wfact_t
nWfact_ts[] = 0
sol_Wfact_t = solve(prob_Wfact_t, MethodOfSteps(alg))

# check number of function evaluations
if alg isa Rosenbrock23
@test_broken !iszero(nWfact_ts[])
@test_broken nWfact_ts[] == njacs[]
@test_broken iszero(sol_Wfact_t.destats.njacs)
else
@test !iszero(nWfact_ts[])
@test_broken nWfact_ts[] == njacs[]
@test iszero(sol_Wfact_t.destats.njacs)
end
@test_broken nWfact_ts[] == sol_Wfact_t.destats.nw

@test sol_wo_jac.t sol_w_jac.t
@test sol_wo_jac.u sol_w_jac.u
# check resulting solution
if alg isa Rosenbrock23
@test sol.t sol_Wfact_t.t
@test sol.u sol_Wfact_t.u
else
@test_broken sol.t sol_Wfact_t.t
@test_broken sol.u sol_Wfact_t.u
end
end
end

@testset "out-of-place" begin
# define functions (Hutchinson's equation)
f(u, h, p, t) = [u[1] * (1 - h(p, t - 1)[1])]
f(u, h, p, t) = u[1] .* (1 .- h(p, t - 1))

g(u, h, p, t) = fill(1 - h(p, t - 1)[1], 1, 1)
njacs = Ref(0)
function jac(u, h, p, t)
njacs[] += 1
reshape(1 .- h(p, t - 1), 1, 1)
end

nWfacts = Ref(0)
function Wfact(u, h, p, dtgamma, t)
nWfacts[] += 1
reshape(dtgamma .* (1 .- h(p, t - 1)) .- 1, 1, 1)
end

nWfact_ts = Ref(0)
function Wfact_t(u, h, p, dtgamma, t)
nWfact_ts[] += 1
reshape((1 - inv(dtgamma)) .- h(p, t - 1), 1, 1)
end

h(p, t) = [0.0]

# define problems
prob_wo_jac = DDEProblem(DDEFunction{false}(f), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob_w_jac = DDEProblem(DDEFunction{false}(f; jac = g), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob = DDEProblem(DDEFunction{false}(f), [1.0], h, (0.0, 40.0); constant_lags = [1])
prob_jac = remake(prob; f = DDEFunction{false}(f; jac = jac))
prob_Wfact = remake(prob; f = DDEFunction{false}(f; Wfact = Wfact))
prob_Wfact_t = remake(prob; f = DDEFunction{false}(f; Wfact_t = Wfact_t))

# compute solutions
for alg in (Rosenbrock23(), TRBDF2())
sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg))
sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg))
sol = solve(prob, MethodOfSteps(alg))

## Jacobian
njacs[] = 0
sol_jac = solve(prob_jac, MethodOfSteps(alg))

# check number of function evaluations
@test !iszero(njacs[])
@test_broken njacs[] == sol_jac.destats.njacs
if alg isa Rosenbrock23
@test njacs[] == sol_jac.destats.nw
else
@test_broken njacs[] == sol_jac.destats.nw
end

# check resulting solution
@test sol.t sol_jac.t
@test sol.u sol_jac.u

## Wfact
nWfacts[] = 0
sol_Wfact = solve(prob_Wfact, MethodOfSteps(alg))

# check number of function evaluations
@test_broken !iszero(nWfacts[])
@test_broken nWfacts[] == njacs[]
@test_broken iszero(sol_Wfact.destats.njacs)
@test_broken nWfacts[] == sol_Wfact.destats.nw

# check resulting solution
@test sol.t sol_Wfact.t
@test sol.u sol_Wfact.u

## Wfact_t
nWfact_ts[] = 0
sol_Wfact_t = solve(prob_Wfact_t, MethodOfSteps(alg))

# check number of function evaluations
@test_broken !iszero(nWfact_ts[])
@test_broken nWfact_ts[] == njacs[]
@test_broken iszero(sol_Wfact_ts.destats.njacs)
@test_broken nWfact_ts[] == sol_Wfact_t.destats.nw

@test sol_wo_jac.t sol_w_jac.t
@test sol_wo_jac.u sol_w_jac.u
# check resulting solution
@test sol.t sol_Wfact_t.t
@test sol.u sol_Wfact_t.u
end
end

0 comments on commit 1576486

Please sign in to comment.