Skip to content

Commit

Permalink
Merge pull request #2297 from oscardssmith/os/always-w-transform
Browse files Browse the repository at this point in the history
Make W_transform always true
  • Loading branch information
ChrisRackauckas authored Sep 17, 2024
2 parents 09aa469 + 783c88a commit b0f957c
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 265 deletions.
272 changes: 55 additions & 217 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,11 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
nothing
end

function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2,
::Val{transform} = Val(true)) where {transform, F1, F2}
function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2) where {F1, F2}
haslinsolve = hasfield(typeof(alg), :linsolve)

if !DiffEqBase.has_jac(f) && # No Jacobian if has analytical solution
(transform || !DiffEqBase.has_Wfact(f)) && # No Jacobian if has_Wfact and Wfact is the one that's used
(!transform || !DiffEqBase.has_Wfact_t(f)) && # No Jacobian has_Wfact and Wfact_t is the one that's used
(!DiffEqBase.has_Wfact_t(f)) &&
((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve && # No Jacobian if linsolve doesn't want it
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) ||
(concrete_jac(alg) !== nothing && concrete_jac(alg))) # Jacobian if explicitly asked for
Expand Down
32 changes: 15 additions & 17 deletions lib/OrdinaryDiffEqExtrapolation/src/extrapolation_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ function perform_step!(integrator, cache::ImplicitEulerExtrapolationCache,
calc_J!(J, integrator, cache) # Store the calculated jac as it won't change in internal discretisation
for index in 1:(n_curr + 1)
dt_temp = dt / sequence[index]
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J, true)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J)
integrator.stats.nw += 1
@.. broadcast=false k_tmps[1]=integrator.fsalfirst
@.. broadcast=false u_tmps[1]=uprev
Expand Down Expand Up @@ -344,9 +344,7 @@ function perform_step!(integrator, cache::ImplicitEulerExtrapolationCache,
endIndex = (i == 1) ? n_curr : n_curr + 1
for index in startIndex:endIndex
dt_temp = dt / sequence[index]
jacobian2W!(
W[Threads.threadid()], integrator.f.mass_matrix, dt_temp, J,
true)
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix, dt_temp, J)
@.. broadcast=false k_tmps[Threads.threadid()]=integrator.fsalfirst
@.. broadcast=false u_tmps[Threads.threadid()]=uprev
for j in 1:sequence[index]
Expand Down Expand Up @@ -445,7 +443,7 @@ function perform_step!(integrator, cache::ImplicitEulerExtrapolationCache,
cache.n_curr = n_curr

dt_temp = dt / sequence[n_curr + 1]
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J, false)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J)
integrator.stats.nw += 1
@.. broadcast=false k_tmps[1]=integrator.fsalfirst
@.. broadcast=false u_tmps[1]=uprev
Expand Down Expand Up @@ -1170,7 +1168,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
for i in 0:n_curr
j_int = 4 * subdividing_sequence[i + 1]
dt_int = dt / j_int # Stepsize of the ith internal discretisation
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
integrator.stats.nw += 1
@.. broadcast=false u_temp2=uprev
@.. broadcast=false linsolve_tmps[1]=fsalfirst
Expand Down Expand Up @@ -1241,7 +1239,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
j_int_temp = 4 * subdividing_sequence[index + 1]
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
dt_int_temp, J, true)
dt_int_temp, J)
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst

Expand Down Expand Up @@ -1326,7 +1324,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
j_int_temp = 4 * subdividing_sequence[index + 1]
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
dt_int_temp, J, true)
dt_int_temp, J)
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst

Expand Down Expand Up @@ -1450,7 +1448,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
# Update cache.T
j_int = 4 * subdividing_sequence[n_curr + 1]
dt_int = dt / j_int # Stepsize of the new internal discretisation
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
integrator.stats.nw += 1
@.. broadcast=false u_temp2=uprev
@.. broadcast=false linsolve_tmps[1]=fsalfirst
Expand Down Expand Up @@ -2536,7 +2534,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
for i in 0:n_curr
j_int = 4 * subdividing_sequence[i + 1]
dt_int = dt / j_int # Stepsize of the ith internal discretisation
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
integrator.stats.nw += 1
@.. broadcast=false u_temp2=uprev
@.. broadcast=false linsolve_tmps[1]=fsalfirst
Expand Down Expand Up @@ -2610,7 +2608,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
j_int_temp = 4 * subdividing_sequence[index + 1]
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
dt_int_temp, J, true)
dt_int_temp, J)
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst

Expand Down Expand Up @@ -2701,7 +2699,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
index == -1 && continue
j_int_temp = 4 * subdividing_sequence[index + 1]
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
jacobian2W!(W[tid], integrator.f.mass_matrix, dt_int_temp, J, true)
jacobian2W!(W[tid], integrator.f.mass_matrix, dt_int_temp, J)
@.. broadcast=false u_temp4[tid]=uprev
@.. broadcast=false linsolvetmp=fsalfirst

Expand Down Expand Up @@ -2815,7 +2813,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
# Update cache.T
j_int = 4 * subdividing_sequence[n_curr + 1]
dt_int = dt / j_int # Stepsize of the new internal discretisation
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
integrator.stats.nw += 1
@.. broadcast=false u_temp2=uprev
@.. broadcast=false linsolve_tmps[1]=fsalfirst
Expand Down Expand Up @@ -3227,7 +3225,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
for i in 0:n_curr
j_int = sequence_factor * subdividing_sequence[i + 1]
dt_int = dt / j_int # Stepsize of the ith internal discretisation
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
integrator.stats.nw += 1
@.. broadcast=false u_temp2=uprev
@.. broadcast=false linsolve_tmps[1]=fsalfirst
Expand Down Expand Up @@ -3301,7 +3299,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
j_int_temp = sequence_factor * subdividing_sequence[index + 1]
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
dt_int_temp, J, true)
dt_int_temp, J)
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst

Expand Down Expand Up @@ -3389,7 +3387,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
j_int_temp = sequence_factor * subdividing_sequence[index + 1]
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
dt_int_temp, J, true)
dt_int_temp, J)
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst

Expand Down Expand Up @@ -3519,7 +3517,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
# Update cache.T
j_int = sequence_factor * subdividing_sequence[n_curr + 1]
dt_int = dt / j_int # Stepsize of the new internal discretisation
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
integrator.stats.nw += 1
@.. broadcast=false u_temp2=uprev
@.. broadcast=false linsolve_tmps[1]=fsalfirst
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,9 @@ function gen_constant_perform_step(tabmask::RosenbrockTableau{Bool,Bool},cachena

# Time derivative
tf.u = uprev
dT = ForwardDiff.derivative(tf, t)
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtgamma, repeat_step, true)
W = calc_W(integrator, cache, dtgamma, repeat_step)
linsolve_tmp = integrator.fsalfirst + dtd1*dT #calc_rosenbrock_differentiation!

$(iterexprs...)
Expand Down Expand Up @@ -476,7 +476,7 @@ function gen_perform_step(tabmask::RosenbrockTableau{Bool,Bool},cachename::Symbo
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t)

calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)

linsolve = cache.linsolve

Expand Down
28 changes: 14 additions & 14 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down Expand Up @@ -155,7 +155,7 @@ end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down Expand Up @@ -259,7 +259,7 @@ end
# Time derivative
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtγ, repeat_step, true)
W = calc_W(integrator, cache, dtγ, repeat_step)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
Expand Down Expand Up @@ -338,7 +338,7 @@ end
# Time derivative
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtγ, repeat_step, true)
W = calc_W(integrator, cache, dtγ, repeat_step)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
Expand Down Expand Up @@ -444,7 +444,7 @@ end
# Time derivative
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtgamma, repeat_step, true)
W = calc_W(integrator, cache, dtgamma, repeat_step)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
Expand Down Expand Up @@ -515,7 +515,7 @@ end
dtd3 = dt * d3
dtgamma = dt * gamma

calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down Expand Up @@ -623,7 +623,7 @@ end
tf.u = uprev
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtgamma, repeat_step, true)
W = calc_W(integrator, cache, dtgamma, repeat_step)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
Expand Down Expand Up @@ -710,7 +710,7 @@ end
dtd4 = dt * d4
dtgamma = dt * gamma

calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down Expand Up @@ -876,7 +876,7 @@ end
tf.u = uprev
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtgamma, repeat_step, true)
W = calc_W(integrator, cache, dtgamma, repeat_step)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
Expand Down Expand Up @@ -1018,7 +1018,7 @@ end
f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation!
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down Expand Up @@ -1226,7 +1226,7 @@ end
tf.u = uprev
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtgamma, repeat_step, true)
W = calc_W(integrator, cache, dtgamma, repeat_step)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
Expand Down Expand Up @@ -1317,7 +1317,7 @@ end
f(cache.fsalfirst, uprev, p, t)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

calc_rosenbrock_differentiation!(integrator, cache, dtd[1], dtgamma, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtd[1], dtgamma, repeat_step)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down Expand Up @@ -1449,7 +1449,7 @@ end
# Time derivative
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtgamma, repeat_step, true)
W = calc_W(integrator, cache, dtgamma, repeat_step)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
Expand Down Expand Up @@ -1662,7 +1662,7 @@ end
f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation!
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down
8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,

### Jacobian does not need to be re-evaluated after an event
### Since it's unchanged
jacobian2W!(W, mass_matrix, dtγ, J, true)
jacobian2W!(W, mass_matrix, dtγ, J)

linsolve = cache.linsolve

Expand Down Expand Up @@ -215,7 +215,7 @@ function _ode_addsteps!(

### Jacobian does not need to be re-evaluated after an event
### Since it's unchanged
jacobian2W!(W, mass_matrix, dtgamma, J, true)
jacobian2W!(W, mass_matrix, dtgamma, J)

linsolve = cache.linsolve

Expand Down Expand Up @@ -394,7 +394,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCache,
@.. linsolve_tmp = @muladd fsalfirst + dtgamma * dT

# Jacobian does not need to be re-evaluated after an event since it's unchanged
jacobian2W!(W, mass_matrix, dtgamma, J, true)
jacobian2W!(W, mass_matrix, dtgamma, J)

linsolve = cache.linsolve

Expand Down Expand Up @@ -623,7 +623,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rosenbrock5Cache,

### Jacobian does not need to be re-evaluated after an event
### Since it's unchanged
jacobian2W!(W, mass_matrix, dtgamma, J, true)
jacobian2W!(W, mass_matrix, dtgamma, J)

linsolve = cache.linsolve

Expand Down
4 changes: 2 additions & 2 deletions test/interface/utility_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using OrdinaryDiffEq.OrdinaryDiffEqDifferentiation: WOperator, calc_W, calc_W!,
tspan = (0.0, 1.0)
dt = 0.01
dtgamma = 0.5dt
concrete_W = -mm + dtgamma * A
concrete_W = A - inv(dtgamma)*mm

# Out-of-place
fun = ODEFunction((u, p, t) -> A * u;
Expand Down Expand Up @@ -39,7 +39,7 @@ using OrdinaryDiffEq.OrdinaryDiffEqDifferentiation: WOperator, calc_W, calc_W!,

# But jacobian2W! will update the cache
jacobian2W!(integrator.cache.nlsolver.cache.W._concrete_form, mm,
dtgamma, integrator.cache.nlsolver.cache.W.J.A, false)
dtgamma, integrator.cache.nlsolver.cache.W.J.A)
@test convert(AbstractMatrix, integrator.cache.nlsolver.cache.W) == concrete_W
ldiv!(tmp, lu!(integrator.cache.nlsolver.cache.W), u0)
@test tmp == concrete_W \ u0
Expand Down
6 changes: 2 additions & 4 deletions test/interface/wprototype_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ for prob in (prob_ode_vanderpol_stiff,)
update_func = (old_val, u, p, t; dtgamma) -> dtgamma,
accepted_kwargs = (:dtgamma,))
transform_op = ScalarOperator(0.0;
update_func = (old_op, u, p, t; dtgamma, transform) -> transform ?
inv(dtgamma) :
one(dtgamma),
accepted_kwargs = (:dtgamma, :transform))
update_func = (old_op, u, p, t; dtgamma) -> inv(dtgamma),
accepted_kwargs = (:dtgamma,))
W_op = -(I - gamma_op * J_op) * transform_op

# Make problem with custom MatrixOperator jac_prototype
Expand Down

0 comments on commit b0f957c

Please sign in to comment.