Skip to content

Commit

Permalink
Merge pull request #467 from JuliaDiffEq/interpolants
Browse files Browse the repository at this point in the history
Refactor interpolants
  • Loading branch information
ChrisRackauckas authored Aug 5, 2018
2 parents 7a1af96 + 36bacb4 commit 896e650
Show file tree
Hide file tree
Showing 5 changed files with 1,112 additions and 1,488 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ module OrdinaryDiffEq
set_abstol!, postamble!, last_step_failed,
isautodifferentiable

using DiffEqBase: check_error!
using DiffEqBase: check_error!, @def

macro tight_loop_macros(ex)
:($(esc(ex)))
Expand Down
290 changes: 125 additions & 165 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,17 @@ function ode_addsteps!(k,t,uprev,u,dt,f,p,cache,always_calc_begin = false,allow_
nothing
end

"""
ode_interpolant and ode_interpolant! dispatch
"""
function ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}}) where TI
_ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T)
end

function ode_interpolant(Θ,dt,y₀,y₁,k,cache::OrdinaryDiffEqMutableCache,idxs,T::Type{Val{TI}}) where TI
if typeof(idxs) <: Number || typeof(y₀) <: Number
return ode_interpolant!(nothing,Θ,dt,y₀,y₁,k,cache,idxs,T)
if typeof(idxs) <: Number || typeof(y₀) <: Union{Number,SArray}
# typeof(y₀) can be these if saveidxs gives a single value
_ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T)
else
# determine output type
# required for calculation of time derivatives with autodifferentiation
Expand All @@ -342,236 +350,188 @@ function ode_interpolant(Θ,dt,y₀,y₁,k,cache::OrdinaryDiffEqMutableCache,idx
else
out = similar(y₀, S, axes(idxs))
end
ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
return out
_ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
end
end

function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}}) where TI
_ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
end

##################### Hermite Interpolants

# If no dispatch found, assume Hermite
function ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}}) where TI
hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T)
function _ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}}) where TI
hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T)
end

function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}}) where TI
hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
function _ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}}) where TI
hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T)
end

"""
Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Problems Page 190
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
if typeof(idxs) <: Nothing
#out = @. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
out = (1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
else
#out = similar(y₀,axes(idxs))
#@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
@views out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
#@. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
(1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
# return @. (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
#@. out = (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
@inbounds for i in eachindex(out)
out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
if typeof(idxs) <: Nothing
#out = @. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
else
#out = similar(y₀,axes(idxs))
#@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
@views out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
#@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
@inbounds for (j,i) in enumerate(idxs)
out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
if typeof(idxs) <: Nothing
#out = @. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
else
#out = similar(y₀,axes(idxs))
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@views out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
#@. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
# return @. k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
return k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
#@. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
@inbounds for i in eachindex(out)
out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
if typeof(idxs) <: Nothing
#out = @. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
else
#out = similar(y₀,axes(idxs))
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
@views out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
#@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
@inbounds for (j,i) in enumerate(idxs)
out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
out
end

"""
Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Problems Page 190
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
if out == nothing
if idxs == nothing
# return @. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
return (1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
else
# return @. (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
end
elseif idxs == nothing
#@. out = (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
@inbounds for i in eachindex(out)
out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
else
#@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
@inbounds for (j,i) in enumerate(idxs)
out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
#@. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
#out = similar(y₀,axes(idxs))
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@views out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
#@. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
@inbounds for i in eachindex(out)
out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
if out == nothing
if idxs == nothing
# return @. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
return k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
else
# return @. k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
return k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
end
elseif idxs == nothing
#@. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
@inbounds for i in eachindex(out)
out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
else
#@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
@inbounds for (j,i) in enumerate(idxs)
out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@inbounds for (j,i) in enumerate(idxs)
out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
if out == nothing
if idxs == nothing
return (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
else
return (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
end
elseif idxs == nothing
#@. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
@inbounds for i in eachindex(out)
out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
else
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@inbounds for (j,i) in enumerate(idxs)
out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{3}}) # Default interpolant is Hermite
#@. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
#out = similar(y₀,axes(idxs))
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
@views out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{3}}) # Default interpolant is Hermite
# @. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
for i in eachindex(out)
out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
if out == nothing
if idxs == nothing
# return @. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
return (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
else
# return @. (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
return (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
end
elseif idxs == nothing
# @. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
for i in eachindex(out)
out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
else
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
for (j,i) in enumerate(idxs)
out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
for (j,i) in enumerate(idxs)
out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
out
end

######################## Linear Interpolants

@muladd function linear_interpolant(Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{0}})
Θm1 = (1-Θ)
@. Θm1*y₀ + Θ*y₁
end

@muladd function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
Θm1 = (1-Θ)
if typeof(idxs) <: Nothing
out = @. Θm1*y₀ + Θ*y₁
else
out = @. Θm1*y₀[idxs] + Θ*y₁[idxs]
end
out
@. Θm1*y₀[idxs] + Θ*y₁[idxs]
end

function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
if typeof(idxs) <: Nothing
out = @. (y₁ - y₀)/dt
else
out = @. (y₁[idxs] - y₀[idxs])/dt
end
@muladd function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{0}})
Θm1 = (1-Θ)
@. out = Θm1*y₀ + Θ*y₁
out
end

"""
Linear Interpolation
"""
@muladd function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
Θm1 = (1-Θ)
if out == nothing
if idxs == nothing
return @. Θm1*y₀ + Θ*y₁
else
return @. Θm1*y₀[idxs] + Θ*y₁[idxs]
end
elseif idxs == nothing
@. out = Θm1*y₀ + Θ*y₁
else
@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
end
@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
out
end

"""
Linear Interpolation
"""
function linear_interpolant(Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{1}})
@. (y₁ - y₀)/dt
end

function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
@. (y₁[idxs] - y₀[idxs])/dt
end

function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{1}})
@. out = (y₁ - y₀)/dt
out
end

function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
if out == nothing
if idxs == nothing
return @. (y₁ - y₀)/dt
else
return @. (y₁[idxs] - y₀[idxs])/dt
end
elseif idxs == nothing
@. out = (y₁ - y₀)/dt
else
@views @. out = (y₁[idxs] - y₀[idxs])/dt
end
@views @. out = (y₁[idxs] - y₀[idxs])/dt
out
end
Loading

0 comments on commit 896e650

Please sign in to comment.