Skip to content

Commit

Permalink
Rosenbrock refactor Rodas5*
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Sep 18, 2024
1 parent b0f957c commit 0054db2
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 1,542 deletions.
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end
function DiffEqBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
RosenbrockCache, Rodas23WCache, Rodas3PCache}}
dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" :
"1st order linear"
Expand All @@ -20,8 +20,8 @@ end
function DiffEqBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{Rosenbrock5ConstantCache,
Rosenbrock5Cache}}
Union{RosenbrockCombinedConstantCache,
RosenbrockCache}}
dense ? "specialized 4rd order \"free\" stiffness-aware interpolation" :
"1st order linear"
end
122 changes: 27 additions & 95 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,24 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
order::Int
end
function full_cache(c::RosenbrockCache)
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
end

struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
order::Int
end

@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
Expand Down Expand Up @@ -702,16 +714,6 @@ end

### Rodas4 methods

struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
end

tabtype(::Rodas4) = Rodas4Tableau
tabtype(::Rodas42) = Rodas42Tableau
tabtype(::Rodas4P) = Rodas4PTableau
Expand All @@ -727,10 +729,10 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rodas4ConstantCache(tf, uf,
RosenbrockCombinedConstantCache(tf, uf,
tabtype(alg)(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve,
alg_autodiff(alg))
alg_autodiff(alg), 4)
end

function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
Expand Down Expand Up @@ -783,81 +785,22 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg,
alg.step_limiter!, alg.stage_limiter!)
alg.step_limiter!, alg.stage_limiter!, 4)
end

################################################################################

### Rosenbrock5

struct Rosenbrock5ConstantCache{TF, UF, Tab, JType, WType, F} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
end

@cache mutable struct Rosenbrock5Cache{
uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
uprev::uType
dense1::rateType
dense2::rateType
dense3::rateType
du::rateType
du1::rateType
du2::rateType
k1::rateType
k2::rateType
k3::rateType
k4::rateType
k5::rateType
k6::rateType
k7::rateType
k8::rateType
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense1 = zero(rate_prototype)
dense2 = zero(rate_prototype)
dense3 = zero(rate_prototype)
dense = [zero(rate_prototype) for _ in 1:3]
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
k6 = zero(rate_prototype)
k7 = zero(rate_prototype)
k8 = zero(rate_prototype)
ks = [zero(rate_prototype) for _ in 1:7]
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
Expand All @@ -881,12 +824,11 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
k5, k6, k7, k8,
RosenbrockCache(u, uprev, dense, du, du1, du2, ks,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
alg.stage_limiter!, 5)
end

function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -898,30 +840,21 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rosenbrock5ConstantCache(tf, uf,
RosenbrockCombinedConstantCache(tf, uf,
Rodas5Tableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve)
constvalue(tTypeNoUnits)), J, W, linsolve, alg_autodiff(alg), 5)
end

function alg_cache(
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense1 = zero(rate_prototype)
dense2 = zero(rate_prototype)
dense3 = zero(rate_prototype)
dense = [zero(rate_prototype) for _ in 1:3]
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
k6 = zero(rate_prototype)
k7 = zero(rate_prototype)
k8 = zero(rate_prototype)
ks = [zero(rate_prototype) for _ in 1:8]
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
Expand All @@ -945,12 +878,11 @@ function alg_cache(
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
k5, k6, k7, k8,
RosenbrockCache(u, uprev, dense, du, du1, du2, ks,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
alg.stage_limiter!, 5)
end

function alg_cache(
Expand All @@ -963,9 +895,9 @@ function alg_cache(
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rosenbrock5ConstantCache(tf, uf,
RosenbrockCombinedConstantCache(tf, uf,
Rodas5PTableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve)
constvalue(tTypeNoUnits)), J, W, linsolve, alg_autodiff(alg), 5)
end

function get_fsalfirstlast(
Expand Down
Loading

0 comments on commit 0054db2

Please sign in to comment.