Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to the fast broadcast implementation #716

Merged
merged 12 commits into from
Apr 13, 2019
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ module OrdinaryDiffEq
set_abstol!, postamble!, last_step_failed,
isautodifferentiable

using DiffEqBase: check_error!, @def
using DiffEqBase: check_error!, @def, @..

macro tight_loop_macros(ex)
:($(esc(ex)))
Expand Down
14 changes: 7 additions & 7 deletions src/adams_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function ϕ_and_ϕstar!(cache, du, k)
β[i] = β[i-1] * ξ/ξ0
ξ += dts[i]
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@. ϕstar_n[i] = β[i] * ϕ_n[i]
@.. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@.. ϕstar_n[i] = β[i] * ϕ_n[i]
else
ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
ϕstar_n[i] = β[i] * ϕ_n[i]
Expand All @@ -47,8 +47,8 @@ function ϕ_and_ϕstar!(cache::Union{VCABMConstantCache,VCABMCache}, du, k)
β[i] = β[i-1] * ξ/ξ0
ξ += dts[i]
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@. ϕstar_n[i] = β[i] * ϕ_n[i]
@.. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@.. ϕstar_n[i] = β[i] * ϕ_n[i]
else
ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
ϕstar_n[i] = β[i] * ϕ_n[i]
Expand All @@ -64,8 +64,8 @@ function expand_ϕ_and_ϕstar!(cache, i)
ξ0 += dts[i]
β[i] = β[i-1] * ξ/ξ0
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@. ϕstar_n[i] = β[i] * ϕ_n[i]
@.. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@.. ϕstar_n[i] = β[i] * ϕ_n[i]
else
ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
ϕstar_n[i] = β[i] * ϕ_n[i]
Expand All @@ -78,7 +78,7 @@ function ϕ_np1!(cache, du_np1, k)
for i = 1:k
if i != 1
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_np1[i] = ϕ_np1[i-1] - ϕstar_n[i-1]
@.. ϕ_np1[i] = ϕ_np1[i-1] - ϕstar_n[i-1]
else
ϕ_np1[i] = ϕ_np1[i-1] - ϕstar_n[i-1]
end
Expand Down
4 changes: 2 additions & 2 deletions src/bdf_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function backward_diff!(cache::OrdinaryDiffEqMutableCache, D, D2, k, flag=true)
flag && copyto!(D[1], D2[1,1])
for i = 2:k
for j = 1:(k-i+1)
@. D2[i,j] = D2[i-1,j] - D2[i-1,j+1]
@.. D2[i,j] = D2[i-1,j] - D2[i-1,j+1]
end
flag && copyto!(D[i], D2[i,1])
end
Expand All @@ -52,7 +52,7 @@ function reinterpolate_history!(cache::OrdinaryDiffEqMutableCache, D, R, k)
fill!(tmp,zero(eltype(D[1])))
for j = 1:k
for k = 1:k
@. tmp += D[k] * R[k,j]
@.. tmp += D[k] * R[k,j]
end
D[j] .= tmp
fill!(tmp, zero(eltype(tmp)))
Expand Down
112 changes: 56 additions & 56 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,163 +389,163 @@ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Proble
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
#@. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
#@.. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
(1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
# return @. (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
# return @.. (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
#@. out = (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
@inbounds for i in eachindex(out)
out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
out
@.. out = (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
#@inbounds for i in eachindex(out)
# out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
#@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
@inbounds for (j,i) in enumerate(idxs)
out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
out
@views @.. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
#end
#out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
#@. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
#@.. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
# return @. k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
# return @.. k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
return k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
#@. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
@inbounds for i in eachindex(out)
out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
out
@.. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
#@inbounds for i in eachindex(out)
# out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
#@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
@inbounds for (j,i) in enumerate(idxs)
out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
out
@views @.. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
#end
#out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
#@. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
#@.. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
#out = similar(y₀,axes(idxs))
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
#@views @.. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@views out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
#@. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
@inbounds for i in eachindex(out)
out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
out
@.. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
#@inbounds for i in eachindex(out)
# out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@inbounds for (j,i) in enumerate(idxs)
out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
out
@views @.. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
#end
#out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{3}}) # Default interpolant is Hermite
#@. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
#@.. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
#out = similar(y₀,axes(idxs))
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
#@views @.. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
@views out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{3}}) # Default interpolant is Hermite
# @. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
for i in eachindex(out)
out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
out
@.. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
#for i in eachindex(out)
# out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
for (j,i) in enumerate(idxs)
out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
out
@views @.. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
#for (j,i) in enumerate(idxs)
# out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
#end
#out
end

######################## Linear Interpolants

@muladd function linear_interpolant(Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{0}})
Θm1 = (1-Θ)
@. Θm1*y₀ + Θ*y₁
@.. Θm1*y₀ + Θ*y₁
end

@muladd function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
Θm1 = (1-Θ)
@. Θm1*y₀[idxs] + Θ*y₁[idxs]
@.. Θm1*y₀[idxs] + Θ*y₁[idxs]
end

@muladd function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{0}})
Θm1 = (1-Θ)
@. out = Θm1*y₀ + Θ*y₁
@.. out = Θm1*y₀ + Θ*y₁
out
end

@muladd function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
Θm1 = (1-Θ)
@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
@views @.. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
out
end

"""
Linear Interpolation
"""
function linear_interpolant(Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{1}})
@. (y₁ - y₀)/dt
@.. (y₁ - y₀)/dt
end

function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
@. (y₁[idxs] - y₀[idxs])/dt
@.. (y₁[idxs] - y₀[idxs])/dt
end

function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{1}})
@. out = (y₁ - y₀)/dt
@.. out = (y₁ - y₀)/dt
out
end

function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
@views @. out = (y₁[idxs] - y₀[idxs])/dt
@views @.. out = (y₁[idxs] - y₀[idxs])/dt
out
end
50 changes: 25 additions & 25 deletions src/dense/high_order_rk_addsteps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
end
end

#=
@muladd function DiffEqBase.addsteps!(k,t,uprev,u,dt,f,p,cache::DP8Cache,always_calc_begin = false,allow_calc_end = true,force_calc_end = false)
if length(k)<7 || always_calc_begin
@unpack c7,c8,c9,c10,c11,c6,c5,c4,c3,c2,b1,b6,b7,b8,b9,b10,b11,b12,a0201,a0301,a0302,a0401,a0403,a0501,a0503,a0504,a0601,a0604,a0605,a0701,a0704,a0705,a0706,a0801,a0804,a0805,a0806,a0807,a0901,a0904,a0905,a0906,a0907,a0908,a1001,a1004,a1005,a1006,a1007,a1008,a1009,a1101,a1104,a1105,a1106,a1107,a1108,a1109,a1110,a1201,a1204,a1205,a1206,a1207,a1208,a1209,a1210,a1211 = cache.tab
Expand All @@ -43,48 +42,48 @@ end
utmp = utilde
k = [cache.udiff,cache.bspl,cache.dense_tmp3,cache.dense_tmp4,cache.dense_tmp5,cache.dense_tmp6,cache.dense_tmp7]
f(k1,uprev,p,t)
@. tmp = uprev+dt*(a0201*k1)
@.. tmp = uprev+dt*(a0201*k1)
f(k2,tmp,p,t+c2*dt)
@. tmp = uprev+dt*(a0301*k1+a0302*k2)
@.. tmp = uprev+dt*(a0301*k1+a0302*k2)
f(k3,tmp,p,t+c3*dt)
@. tmp = uprev+dt*(a0401*k1+a0403*k3)
@.. tmp = uprev+dt*(a0401*k1+a0403*k3)
f(k4,tmp,p,t+c4*dt)
@. tmp = uprev+dt*(a0501*k1+a0503*k3+a0504*k4)
@.. tmp = uprev+dt*(a0501*k1+a0503*k3+a0504*k4)
f(k5,tmp,p,t+c5*dt)
@. tmp = uprev+dt*(a0601*k1+a0604*k4+a0605*k5)
@.. tmp = uprev+dt*(a0601*k1+a0604*k4+a0605*k5)
f(k6,tmp,p,t+c6*dt)
@. tmp = uprev+dt*(a0701*k1+a0704*k4+a0705*k5+a0706*k6)
@.. tmp = uprev+dt*(a0701*k1+a0704*k4+a0705*k5+a0706*k6)
f(k7,tmp,p,t+c7*dt)
@. tmp = uprev+dt*(a0801*k1+a0804*k4+a0805*k5+a0806*k6+a0807*k7)
@.. tmp = uprev+dt*(a0801*k1+a0804*k4+a0805*k5+a0806*k6+a0807*k7)
f(k8,tmp,p,t+c8*dt)
@. tmp = uprev+dt*(a0901*k1+a0904*k4+a0905*k5+a0906*k6+a0907*k7+a0908*k8)
@.. tmp = uprev+dt*(a0901*k1+a0904*k4+a0905*k5+a0906*k6+a0907*k7+a0908*k8)
f(k9,tmp,p,t+c9*dt)
@. tmp = uprev+dt*(a1001*k1+a1004*k4+a1005*k5+a1006*k6+a1007*k7+a1008*k8+a1009*k9)
@.. tmp = uprev+dt*(a1001*k1+a1004*k4+a1005*k5+a1006*k6+a1007*k7+a1008*k8+a1009*k9)
f(k10,tmp,p,t+c10*dt)
@. tmp = uprev+dt*(a1101*k1+a1104*k4+a1105*k5+a1106*k6+a1107*k7+a1108*k8+a1109*k9+a1110*k10)
@.. tmp = uprev+dt*(a1101*k1+a1104*k4+a1105*k5+a1106*k6+a1107*k7+a1108*k8+a1109*k9+a1110*k10)
f(k11,tmp,p,t+c11*dt)
@. tmp = uprev+dt*(a1201*k1+a1204*k4+a1205*k5+a1206*k6+a1207*k7+a1208*k8+a1209*k9+a1210*k10+a1211*k11)
@.. tmp = uprev+dt*(a1201*k1+a1204*k4+a1205*k5+a1206*k6+a1207*k7+a1208*k8+a1209*k9+a1210*k10+a1211*k11)
f(k12,tmp,p,t+dt)
@. kupdate = b1*k1+b6*k6+b7*k7+b8*k8+b9*k9+b10*k10+b11*k11+b12*k12
@. utmp = uprev + dt*update
@.. kupdate = b1*k1+b6*k6+b7*k7+b8*k8+b9*k9+b10*k10+b11*k11+b12*k12
@.. utmp = uprev + dt*kupdate
f(k13,utmp,p,t+dt)
@. tmp = uprev+dt*(a1401*k1+a1407*k7+a1408*k8+a1409*k9+a1410*k10+a1411*k11+a1412*k12+a1413*k13)
@.. tmp = uprev+dt*(a1401*k1+a1407*k7+a1408*k8+a1409*k9+a1410*k10+a1411*k11+a1412*k12+a1413*k13)
f(k14,tmp,p,t+c14*dt)
@. tmp = uprev+dt*(a1501*k1+a1506*k6+a1507*k7+a1508*k8+a1511*k11+a1512*k12+a1513*k13+a1514*k14)
@.. tmp = uprev+dt*(a1501*k1+a1506*k6+a1507*k7+a1508*k8+a1511*k11+a1512*k12+a1513*k13+a1514*k14)
f(k15,tmp,p,t+c15*dt)
@. tmp = uprev+dt*(a1601*k1+a1606*k6+a1607*k7+a1608*k8+a1609*k9+a1613*k13+a1614*k14+a1615*k15)
@.. tmp = uprev+dt*(a1601*k1+a1606*k6+a1607*k7+a1608*k8+a1609*k9+a1613*k13+a1614*k14+a1615*k15)
f(k16,tmp,p,t+c16*dt)
@. udiff= kupdate
@. bspl = k1 - udiff
@. k[3] = udiff - k13 - bspl
@. k[4] = (d401*k1+d406*k6+d407*k7+d408*k8+d409*k9+d410*k10+d411*k11+d412*k12+d413*k13+d414*k14+d415*k15+d416*k16)
@. k[5] = (d501*k1+d506*k6+d507*k7+d508*k8+d509*k9+d510*k10+d511*k11+d512*k12+d513*k13+d514*k14+d515*k15+d516*k16)
@. k[6] = (d601*k1+d606*k6+d607*k7+d608*k8+d609*k9+d610*k10+d611*k11+d612*k12+d613*k13+d614*k14+d615*k15+d616*k16)
@. k[7] = (d701*k1+d706*k6+d707*k7+d708*k8+d709*k9+d710*k10+d711*k11+d712*k12+d713*k13+d714*k14+d715*k15+d716*k16)
copyto!(udiff, kupdate)
@.. bspl = k1 - udiff
@.. k[3] = udiff - k13 - bspl
@.. k[4] = (d401*k1+d406*k6+d407*k7+d408*k8+d409*k9+d410*k10+d411*k11+d412*k12+d413*k13+d414*k14+d415*k15+d416*k16)
@.. k[5] = (d501*k1+d506*k6+d507*k7+d508*k8+d509*k9+d510*k10+d511*k11+d512*k12+d513*k13+d514*k14+d515*k15+d516*k16)
@.. k[6] = (d601*k1+d606*k6+d607*k7+d608*k8+d609*k9+d610*k10+d611*k11+d612*k12+d613*k13+d614*k14+d615*k15+d616*k16)
@.. k[7] = (d701*k1+d706*k6+d707*k7+d708*k8+d709*k9+d710*k10+d711*k11+d712*k12+d713*k13+d714*k14+d715*k15+d716*k16)
end
end
=#

#=
@muladd function DiffEqBase.addsteps!(k,t,uprev,u,dt,f,p,cache::DP8Cache,always_calc_begin = false,allow_calc_end = true,force_calc_end = false)
if length(k)<7 || always_calc_begin
@unpack c7,c8,c9,c10,c11,c6,c5,c4,c3,c2,b1,b6,b7,b8,b9,b10,b11,b12,a0201,a0301,a0302,a0401,a0403,a0501,a0503,a0504,a0601,a0604,a0605,a0701,a0704,a0705,a0706,a0801,a0804,a0805,a0806,a0807,a0901,a0904,a0905,a0906,a0907,a0908,a1001,a1004,a1005,a1006,a1007,a1008,a1009,a1101,a1104,a1105,a1106,a1107,a1108,a1109,a1110,a1201,a1204,a1205,a1206,a1207,a1208,a1209,a1210,a1211 = cache.tab
Expand Down Expand Up @@ -167,3 +166,4 @@ end
end
end
end
=#
Loading