Skip to content

Commit

Permalink
Use user-defined Jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jul 29, 2019
1 parent 51992fd commit 08232f7
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/functionwrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,26 @@ struct ODEFunctionWrapper{iip,F,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBas
colorvec::TCV
end

# TODO: make use of other functions
function ODEFunctionWrapper(f::DDEFunction, h)
if f.jac === nothing
jac = nothing
else
if isinplace(f)
jac = let f_jac = f.jac, h = h
(J, u, p, t) -> f_jac(J, u, h, p, t)
end
else
jac = let f_jac = f.jac, h = h
(u, p, t) -> f_jac(u, h, p, t)
end
end
end

ODEFunctionWrapper{isinplace(f),typeof(f.f),typeof(h),typeof(f.mass_matrix),
typeof(f.analytic),typeof(f.tgrad),typeof(f.jac),
typeof(f.analytic),typeof(f.tgrad),typeof(jac),
typeof(f.jac_prototype),typeof(f.Wfact),typeof(f.Wfact_t),
typeof(f.paramjac),typeof(f.syms),typeof(f.colorvec)}(
f.f, h, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.f, h, f.mass_matrix, f.analytic, f.tgrad, jac,
f.jac_prototype, f.Wfact, f.Wfact_t, f.paramjac, f.syms,
f.colorvec)
end
Expand Down
56 changes: 56 additions & 0 deletions test/jacobian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using DelayDiffEq, Test

@testset "in-place" begin
# define functions (Hutchinson's equation)
function f(du, u, h, p, t)
du[1] = u[1] * (1 - h(p, t - 1)[1])
nothing
end

function g(J, u, h, p, t)
println("t")
J[1, 1] = 1 - h(p, t - 1)[1]
nothing
end

h(p, t) = [0.0]

# define problems
prob_wo_jac = DDEProblem(DDEFunction{true}(f), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob_w_jac = DDEProblem(DDEFunction{true}(f; jac = g), [1.0], h, (0.0, 40.0);
constant_lags = [1])

# compute solutions
for alg in (Rosenbrock23(), TRBDF2())
sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg))
sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg))

@test sol_wo_jac.t sol_w_jac.t
@test sol_wo_jac.u sol_w_jac.u
end
end

@testset "out-of-place" begin
# define functions (Hutchinson's equation)
f(u, h, p, t) = [u[1] * (1 - h(p, t - 1)[1])]

g(u, h, p, t) = (println("t"); fill(1 - h(p, t - 1)[1], 1, 1))

h(p, t) = [0.0]

# define problems
prob_wo_jac = DDEProblem(DDEFunction{false}(f), [1.0], h, (0.0, 40.0);
constant_lags = [1])
prob_w_jac = DDEProblem(DDEFunction{false}(f; jac = g), [1.0], h, (0.0, 40.0);
constant_lags = [1])

# compute solutions
for alg in (Rosenbrock23(), TRBDF2())
sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg))
sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg))

@test sol_wo_jac.t sol_w_jac.t
@test sol_wo_jac.u sol_w_jac.u
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ if GROUP == "All" || GROUP == "Interface"
@time @safetestset "Discontinuity Tests" begin include("discontinuities.jl") end
@time @safetestset "History Function Tests" begin include("history_function.jl") end
@time @safetestset "Parameterized Function Tests" begin include("parameters.jl") end
@time @safetestset "Jacobian Tests" begin include("jacobian.jl") end
@time @safetestset "Return Code Tests" begin include("retcode.jl") end
@time @safetestset "Composite Solution Tests" begin include("composite_solution.jl") end
@time @safetestset "Dependent Delay Tests" begin include("dependent_delays.jl") end
Expand Down

0 comments on commit 08232f7

Please sign in to comment.