Skip to content

Commit

Permalink
Merge pull request #451 from JuliaDiffEq/myb/nlsolve
Browse files Browse the repository at this point in the history
Restructure nonlinear solvers
  • Loading branch information
ChrisRackauckas authored Aug 4, 2018
2 parents fe202c8 + 54a8998 commit 9fe2615
Show file tree
Hide file tree
Showing 24 changed files with 2,086 additions and 3,187 deletions.
8 changes: 5 additions & 3 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ module OrdinaryDiffEq

include("misc_utils.jl")
include("algorithms.jl")
include("nlsolve/type.jl")
include("nlsolve/newton.jl")
include("nlsolve/functional.jl")

include("caches/basic_caches.jl")
include("caches/low_order_rk_caches.jl")
Expand All @@ -66,7 +69,6 @@ module OrdinaryDiffEq
include("caches/adams_bashforth_moulton_caches.jl")
include("caches/nordsieck_caches.jl")
include("caches/bdf_caches.jl")
include("caches/sbdf_caches.jl")
include("caches/rkc_caches.jl")
include("caches/euler_imex_caches.jl")

Expand Down Expand Up @@ -108,7 +110,6 @@ module OrdinaryDiffEq
include("perform_step/adams_bashforth_moulton_perform_step.jl")
include("perform_step/nordsieck_perform_step.jl")
include("perform_step/bdf_perform_step.jl")
include("perform_step/sbdf_perform_step.jl")
include("perform_step/rkc_perform_step.jl")
include("perform_step/euler_imex_perform_step.jl")

Expand All @@ -121,7 +122,6 @@ module OrdinaryDiffEq
include("dense/high_order_rk_addsteps.jl")

include("derivative_utils.jl")
include("nlsolve_utils.jl")
include("nordsieck_utils.jl")
include("adams_utils.jl")
include("bdf_utils.jl")
Expand Down Expand Up @@ -200,4 +200,6 @@ module OrdinaryDiffEq

export AutoSwitch, AutoTsit5, AutoDP5,
AutoVern6, AutoVern7, AutoVern8, AutoVern9

export NLNewton, NLAnderson, NLFunctional
end # module
410 changes: 133 additions & 277 deletions src/algorithms.jl

Large diffs are not rendered by default.

146 changes: 28 additions & 118 deletions src/caches/adams_bashforth_moulton_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -922,18 +922,15 @@ end

# CNAB2

mutable struct CNAB2ConstantCache{rateType,F,uToltype,uType,tType} <: OrdinaryDiffEqConstantCache
mutable struct CNAB2ConstantCache{rateType,F,N,uType,tType} <: OrdinaryDiffEqConstantCache
k2::rateType
uf::F
ηold::uToltype
κ::uToltype
tol::uToltype
newton_iters::Int
nlsolve::N
uprev3::uType
tprev2::tType
end

mutable struct CNAB2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,tType,F} <: OrdinaryDiffEqMutableCache
mutable struct CNAB2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,N,tType,F} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
Expand All @@ -953,10 +950,7 @@ mutable struct CNAB2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,tType,F
uf::UF
jac_config::JC
linsolve::F
ηold::uToltype
κ::uToltype
tol::uToltype
newton_iters::Int
nlsolve::N
uprev3::uType
tprev2::tType
end
Expand All @@ -965,90 +959,48 @@ u_cache(c::CNAB2Cache) = ()
du_cache(c::CNAB2Cache) = ()

function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
k2 = rate_prototype
uToltype = real(uBottomEltypeNoUnits)
uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p)
ηold = one(uToltype)
uf != nothing && ( uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p) )
uprev3 = u
tprev2 = t

if alg.κ != nothing
κ = uToltype(alg.κ)
else
κ = uToltype(1//100)
end
if alg.tol != nothing
tol = uToltype(alg.tol)
else
tol = uToltype(min(0.03,first(reltol)^(0.5)))
end

CNAB2ConstantCache(k2,uf,ηold,κ,tol,10000,uprev3,tprev2)
nlsolve = typeof(_nlsolve)(NLSolverCache(κ,tol,min_iter,max_iter,10000,new_W,z,W,1//2,1,ηold,z₊,dz,tmp,b,k))
CNAB2ConstantCache(k2,uf,nlsolve,uprev3,tprev2)
end

function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u))
tmp = similar(u); b = similar(u,axes(u));
@iipnlcachefields
atmp = similar(u,uEltypeNoUnits,axes(u))
fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
du1 = zero(rate_prototype)
du₁ = zero(rate_prototype)

if typeof(f) <: SplitFunction
if typeof(f) <: SplitFunction && uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
else
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
end

linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)

uToltype = real(uBottomEltypeNoUnits)
if alg.κ != nothing
κ = uToltype(alg.κ)
else
κ = uToltype(1//100)
end
if alg.tol != nothing
tol = uToltype(alg.tol)
else
tol = uToltype(min(0.03,first(reltol)^(0.5)))
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end

uprev3 = similar(u)
tprev2 = t

ηold = one(uToltype)

CNAB2Cache(u,uprev,uprev2,fsalfirst,k,k1,k2,du₁,du1,z,dz,b,tmp,atmp,J,W,uf,jac_config,linsolve,ηold,κ,tol,10000,uprev3,tprev2)
nlsolve = typeof(_nlsolve)(NLSolverCache(κ,tol,min_iter,max_iter,10000,new_W,z,W,1//2,1,ηold,z₊,dz,tmp,b,k))
CNAB2Cache(u,uprev,uprev2,fsalfirst,k,k1,k2,du₁,du1,z,dz,b,tmp,atmp,J,W,uf,jac_config,linsolve,nlsolve,uprev3,tprev2)
end

# CNLF2

mutable struct CNLF2ConstantCache{rateType,F,uToltype,uType,tType} <: OrdinaryDiffEqConstantCache
mutable struct CNLF2ConstantCache{rateType,F,N,uType,tType} <: OrdinaryDiffEqConstantCache
k2::rateType
uf::F
ηold::uToltype
κ::uToltype
tol::uToltype
newton_iters::Int
nlsolve::N
uprev2::uType
uprev3::uType
tprev2::tType
end

mutable struct CNLF2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,tType,F} <: OrdinaryDiffEqMutableCache
mutable struct CNLF2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,N,tType,F} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
Expand All @@ -1068,10 +1020,7 @@ mutable struct CNLF2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,tType,F
uf::UF
jac_config::JC
linsolve::F
ηold::uToltype
κ::uToltype
tol::uToltype
newton_iters::Int
nlsolve::N
uprev3::uType
tprev2::tType
end
Expand All @@ -1080,73 +1029,34 @@ u_cache(c::CNLF2Cache) = ()
du_cache(c::CNLF2Cache) = ()

function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
k2 = rate_prototype
uToltype = real(uBottomEltypeNoUnits)
uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p)
ηold = one(uToltype)
uf != nothing && ( uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p) )
uprev2 = u
uprev3 = u
tprev2 = t

if alg.κ != nothing
κ = uToltype(alg.κ)
else
κ = uToltype(1//100)
end
if alg.tol != nothing
tol = uToltype(alg.tol)
else
tol = uToltype(min(0.03,first(reltol)^(0.5)))
end

CNLF2ConstantCache(k2,uf,ηold,κ,tol,10000,uprev2,uprev3,tprev2)
nlsolve = typeof(_nlsolve)(NLSolverCache(κ,tol,min_iter,max_iter,10000,new_W,z,W,1//1,1,ηold,z₊,dz,tmp,b,k))
CNLF2ConstantCache(k2,uf,nlsolve,uprev2,uprev3,tprev2)
end

function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u))
tmp = similar(u); b = similar(u,axes(u));
@iipnlcachefields
atmp = similar(u,uEltypeNoUnits,axes(u))
fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
du₁ = zero(rate_prototype)
du1 = zero(rate_prototype)

if typeof(f) <: SplitFunction
if typeof(f) <: SplitFunction && uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
else
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
end

linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)

uToltype = real(uBottomEltypeNoUnits)
if alg.κ != nothing
κ = uToltype(alg.κ)
else
κ = uToltype(1//100)
end
if alg.tol != nothing
tol = uToltype(alg.tol)
else
tol = uToltype(min(0.03,first(reltol)^(0.5)))
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end

uprev2 = similar(u)
uprev3 = similar(u)
tprev2 = t

ηold = one(uToltype)

CNLF2Cache(u,uprev,uprev2,fsalfirst,k,k1,k2,du₁,du1,z,dz,b,tmp,atmp,J,W,uf,jac_config,linsolve,ηold,κ,tol,10000,uprev3,tprev2)
nlsolve = typeof(_nlsolve)(NLSolverCache(κ,tol,min_iter,max_iter,10000,new_W,z,W,1//1,1,ηold,z₊,dz,tmp,b,k))
CNLF2Cache(u,uprev,uprev2,fsalfirst,k,k1,k2,du₁,du1,z,dz,b,tmp,atmp,J,W,uf,jac_config,linsolve,nlsolve,uprev3,tprev2)
end
Loading

0 comments on commit 9fe2615

Please sign in to comment.