Skip to content

Commit

Permalink
Merge pull request #1473 from SciML/inference4
Browse files Browse the repository at this point in the history
Improve inference with vectorcontinuouscallbacks
  • Loading branch information
ChrisRackauckas authored Aug 18, 2021
2 parents 679cfcc + 5fef420 commit 02f85ea
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Adapt = "1.1, 2.0, 3.0"
ArrayInterface = "2.7, 3.0"
DataStructures = "0.18"
DiffEqBase = "6.72"
DiffEqBase = "6.73"
DocStringExtensions = "0.8"
ExponentialUtilities = "1.2"
FastClosures = "0.3"
Expand Down
4 changes: 2 additions & 2 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,11 @@ DiffEqBase.nlsolve_f(f, alg::DAEAlgorithm) = f
DiffEqBase.nlsolve_f(integrator::ODEIntegrator) =
nlsolve_f(integrator.f, unwrap_alg(integrator, true))

function (integrator::ODEIntegrator)(t,deriv::Type=Val{0};idxs=nothing)
function (integrator::ODEIntegrator)(t,::Type{deriv}=Val{0};idxs=nothing) where {deriv}
current_interpolant(t,integrator,idxs,deriv)
end

(integrator::ODEIntegrator)(val::AbstractArray,t::Union{Number,AbstractArray},deriv::Type=Val{0};idxs=nothing) = current_interpolant!(val,t,integrator,idxs,deriv)
(integrator::ODEIntegrator)(val::AbstractArray,t::Union{Number,AbstractArray},::Type{deriv}=Val{0};idxs=nothing) where {deriv} = current_interpolant!(val,t,integrator,idxs,deriv)

# Interface used by DelayDiffEq
has_tstop(integrator) = !isempty(integrator.opts.tstops)
Expand Down
10 changes: 5 additions & 5 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,DiffEqBase.
save_end = nothing,
callback = nothing,
dense = save_everystep && !(typeof(alg) <: Union{DAEAlgorithm,FunctionMap}) && isempty(saveat),
calck = (callback !== nothing && callback != CallbackSet()) || (dense) || !isempty(saveat), # and no dense output
calck = (callback !== nothing && callback !== CallbackSet()) || (dense) || !isempty(saveat), # and no dense output
dt = alg isa FunctionMap && isempty(tstops) ? eltype(prob.tspan)(1) : eltype(prob.tspan)(0),
dtmin = nothing,
dtmax = eltype(prob.tspan)((prob.tspan[end]-prob.tspan[1])),
Expand Down Expand Up @@ -207,13 +207,13 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,DiffEqBase.

callbacks_internal = CallbackSet(callback)

max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal)
if max_len_cb isa VectorContinuousCallback
max_len_cb = DiffEqBase.max_vector_callback_length_int(callbacks_internal)
if max_len_cb !== nothing
uBottomEltypeReal = real(uBottomEltype)
if isinplace(prob)
callback_cache = DiffEqBase.CallbackCache(u,max_len_cb.len,uBottomEltypeReal,uBottomEltypeReal)
callback_cache = DiffEqBase.CallbackCache(u,max_len_cb,uBottomEltypeReal,uBottomEltypeReal)
else
callback_cache = DiffEqBase.CallbackCache(max_len_cb.len,uBottomEltypeReal,uBottomEltypeReal)
callback_cache = DiffEqBase.CallbackCache(max_len_cb,uBottomEltypeReal,uBottomEltypeReal)
end
else
callback_cache = nothing
Expand Down

0 comments on commit 02f85ea

Please sign in to comment.