Skip to content

Commit

Permalink
Merge pull request #716 from JuliaDiffEq/myb/fastbc
Browse files Browse the repository at this point in the history
Switch to the fast broadcast implementation
  • Loading branch information
ChrisRackauckas authored Apr 13, 2019
2 parents b521164 + a445431 commit 2d04bba
Show file tree
Hide file tree
Showing 45 changed files with 2,474 additions and 2,208 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ module OrdinaryDiffEq
set_abstol!, postamble!, last_step_failed,
isautodifferentiable

using DiffEqBase: check_error!, @def
using DiffEqBase: check_error!, @def, @..

macro tight_loop_macros(ex)
:($(esc(ex)))
Expand Down
14 changes: 7 additions & 7 deletions src/adams_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function ϕ_and_ϕstar!(cache, du, k)
β[i] = β[i-1] * ξ/ξ0
ξ += dts[i]
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@. ϕstar_n[i] = β[i] * ϕ_n[i]
@.. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@.. ϕstar_n[i] = β[i] * ϕ_n[i]
else
ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
ϕstar_n[i] = β[i] * ϕ_n[i]
Expand All @@ -47,8 +47,8 @@ function ϕ_and_ϕstar!(cache::Union{VCABMConstantCache,VCABMCache}, du, k)
β[i] = β[i-1] * ξ/ξ0
ξ += dts[i]
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@. ϕstar_n[i] = β[i] * ϕ_n[i]
@.. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@.. ϕstar_n[i] = β[i] * ϕ_n[i]
else
ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
ϕstar_n[i] = β[i] * ϕ_n[i]
Expand All @@ -64,8 +64,8 @@ function expand_ϕ_and_ϕstar!(cache, i)
ξ0 += dts[i]
β[i] = β[i-1] * ξ/ξ0
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@. ϕstar_n[i] = β[i] * ϕ_n[i]
@.. ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
@.. ϕstar_n[i] = β[i] * ϕ_n[i]
else
ϕ_n[i] = ϕ_n[i-1] - ϕstar_nm1[i-1]
ϕstar_n[i] = β[i] * ϕ_n[i]
Expand All @@ -78,7 +78,7 @@ function ϕ_np1!(cache, du_np1, k)
for i = 1:k
if i != 1
if typeof(cache) <: OrdinaryDiffEqMutableCache
@. ϕ_np1[i] = ϕ_np1[i-1] - ϕstar_n[i-1]
@.. ϕ_np1[i] = ϕ_np1[i-1] - ϕstar_n[i-1]
else
ϕ_np1[i] = ϕ_np1[i-1] - ϕstar_n[i-1]
end
Expand Down
4 changes: 2 additions & 2 deletions src/bdf_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function backward_diff!(cache::OrdinaryDiffEqMutableCache, D, D2, k, flag=true)
flag && copyto!(D[1], D2[1,1])
for i = 2:k
for j = 1:(k-i+1)
@. D2[i,j] = D2[i-1,j] - D2[i-1,j+1]
@.. D2[i,j] = D2[i-1,j] - D2[i-1,j+1]
end
flag && copyto!(D[i], D2[i,1])
end
Expand All @@ -52,7 +52,7 @@ function reinterpolate_history!(cache::OrdinaryDiffEqMutableCache, D, R, k)
fill!(tmp,zero(eltype(D[1])))
for j = 1:k
for k = 1:k
@. tmp += D[k] * R[k,j]
@.. tmp += D[k] * R[k,j]
end
D[j] .= tmp
fill!(tmp, zero(eltype(tmp)))
Expand Down
112 changes: 56 additions & 56 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,163 +389,163 @@ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Proble
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
#@. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
#@.. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
(1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
# return @. (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
# return @.. (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
#@. out = (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
@inbounds for i in eachindex(out)
out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
out
@.. out = (1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
#@inbounds for i in eachindex(out)
# out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
#@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
@inbounds for (j,i) in enumerate(idxs)
out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
end
out
@views @.. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
#end
#out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
#@. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
#@.. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
# return @. k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
# return @.. k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
return k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
#@. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
@inbounds for i in eachindex(out)
out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
out
@.. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
#@inbounds for i in eachindex(out)
# out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
#@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
@inbounds for (j,i) in enumerate(idxs)
out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
end
out
@views @.. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
#end
#out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
#@. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
#@.. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
#out = similar(y₀,axes(idxs))
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
#@views @.. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@views out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
#@. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
@inbounds for i in eachindex(out)
out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
out
@.. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
#@inbounds for i in eachindex(out)
# out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
@inbounds for (j,i) in enumerate(idxs)
out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
end
out
@views @.. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
#end
#out
end

"""
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
"""
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{3}}) # Default interpolant is Hermite
#@. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
#@.. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
end

@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
#out = similar(y₀,axes(idxs))
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
#@views @.. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
@views out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{3}}) # Default interpolant is Hermite
# @. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
for i in eachindex(out)
out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
out
@.. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
#for i in eachindex(out)
# out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
#end
#out
end

@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
for (j,i) in enumerate(idxs)
out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
end
out
@views @.. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
#for (j,i) in enumerate(idxs)
# out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
#end
#out
end

######################## Linear Interpolants

@muladd function linear_interpolant(Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{0}})
Θm1 = (1-Θ)
@. Θm1*y₀ + Θ*y₁
@.. Θm1*y₀ + Θ*y₁
end

@muladd function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
Θm1 = (1-Θ)
@. Θm1*y₀[idxs] + Θ*y₁[idxs]
@.. Θm1*y₀[idxs] + Θ*y₁[idxs]
end

@muladd function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{0}})
Θm1 = (1-Θ)
@. out = Θm1*y₀ + Θ*y₁
@.. out = Θm1*y₀ + Θ*y₁
out
end

@muladd function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
Θm1 = (1-Θ)
@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
@views @.. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
out
end

"""
Linear Interpolation
"""
function linear_interpolant(Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{1}})
@. (y₁ - y₀)/dt
@.. (y₁ - y₀)/dt
end

function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
@. (y₁[idxs] - y₀[idxs])/dt
@.. (y₁[idxs] - y₀[idxs])/dt
end

function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs::Nothing,T::Type{Val{1}})
@. out = (y₁ - y₀)/dt
@.. out = (y₁ - y₀)/dt
out
end

function linear_interpolant!(out,Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
@views @. out = (y₁[idxs] - y₀[idxs])/dt
@views @.. out = (y₁[idxs] - y₀[idxs])/dt
out
end
50 changes: 25 additions & 25 deletions src/dense/high_order_rk_addsteps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
end
end

#=
@muladd function DiffEqBase.addsteps!(k,t,uprev,u,dt,f,p,cache::DP8Cache,always_calc_begin = false,allow_calc_end = true,force_calc_end = false)
if length(k)<7 || always_calc_begin
@unpack c7,c8,c9,c10,c11,c6,c5,c4,c3,c2,b1,b6,b7,b8,b9,b10,b11,b12,a0201,a0301,a0302,a0401,a0403,a0501,a0503,a0504,a0601,a0604,a0605,a0701,a0704,a0705,a0706,a0801,a0804,a0805,a0806,a0807,a0901,a0904,a0905,a0906,a0907,a0908,a1001,a1004,a1005,a1006,a1007,a1008,a1009,a1101,a1104,a1105,a1106,a1107,a1108,a1109,a1110,a1201,a1204,a1205,a1206,a1207,a1208,a1209,a1210,a1211 = cache.tab
Expand All @@ -43,48 +42,48 @@ end
utmp = utilde
k = [cache.udiff,cache.bspl,cache.dense_tmp3,cache.dense_tmp4,cache.dense_tmp5,cache.dense_tmp6,cache.dense_tmp7]
f(k1,uprev,p,t)
@. tmp = uprev+dt*(a0201*k1)
@.. tmp = uprev+dt*(a0201*k1)
f(k2,tmp,p,t+c2*dt)
@. tmp = uprev+dt*(a0301*k1+a0302*k2)
@.. tmp = uprev+dt*(a0301*k1+a0302*k2)
f(k3,tmp,p,t+c3*dt)
@. tmp = uprev+dt*(a0401*k1+a0403*k3)
@.. tmp = uprev+dt*(a0401*k1+a0403*k3)
f(k4,tmp,p,t+c4*dt)
@. tmp = uprev+dt*(a0501*k1+a0503*k3+a0504*k4)
@.. tmp = uprev+dt*(a0501*k1+a0503*k3+a0504*k4)
f(k5,tmp,p,t+c5*dt)
@. tmp = uprev+dt*(a0601*k1+a0604*k4+a0605*k5)
@.. tmp = uprev+dt*(a0601*k1+a0604*k4+a0605*k5)
f(k6,tmp,p,t+c6*dt)
@. tmp = uprev+dt*(a0701*k1+a0704*k4+a0705*k5+a0706*k6)
@.. tmp = uprev+dt*(a0701*k1+a0704*k4+a0705*k5+a0706*k6)
f(k7,tmp,p,t+c7*dt)
@. tmp = uprev+dt*(a0801*k1+a0804*k4+a0805*k5+a0806*k6+a0807*k7)
@.. tmp = uprev+dt*(a0801*k1+a0804*k4+a0805*k5+a0806*k6+a0807*k7)
f(k8,tmp,p,t+c8*dt)
@. tmp = uprev+dt*(a0901*k1+a0904*k4+a0905*k5+a0906*k6+a0907*k7+a0908*k8)
@.. tmp = uprev+dt*(a0901*k1+a0904*k4+a0905*k5+a0906*k6+a0907*k7+a0908*k8)
f(k9,tmp,p,t+c9*dt)
@. tmp = uprev+dt*(a1001*k1+a1004*k4+a1005*k5+a1006*k6+a1007*k7+a1008*k8+a1009*k9)
@.. tmp = uprev+dt*(a1001*k1+a1004*k4+a1005*k5+a1006*k6+a1007*k7+a1008*k8+a1009*k9)
f(k10,tmp,p,t+c10*dt)
@. tmp = uprev+dt*(a1101*k1+a1104*k4+a1105*k5+a1106*k6+a1107*k7+a1108*k8+a1109*k9+a1110*k10)
@.. tmp = uprev+dt*(a1101*k1+a1104*k4+a1105*k5+a1106*k6+a1107*k7+a1108*k8+a1109*k9+a1110*k10)
f(k11,tmp,p,t+c11*dt)
@. tmp = uprev+dt*(a1201*k1+a1204*k4+a1205*k5+a1206*k6+a1207*k7+a1208*k8+a1209*k9+a1210*k10+a1211*k11)
@.. tmp = uprev+dt*(a1201*k1+a1204*k4+a1205*k5+a1206*k6+a1207*k7+a1208*k8+a1209*k9+a1210*k10+a1211*k11)
f(k12,tmp,p,t+dt)
@. kupdate = b1*k1+b6*k6+b7*k7+b8*k8+b9*k9+b10*k10+b11*k11+b12*k12
@. utmp = uprev + dt*update
@.. kupdate = b1*k1+b6*k6+b7*k7+b8*k8+b9*k9+b10*k10+b11*k11+b12*k12
@.. utmp = uprev + dt*kupdate
f(k13,utmp,p,t+dt)
@. tmp = uprev+dt*(a1401*k1+a1407*k7+a1408*k8+a1409*k9+a1410*k10+a1411*k11+a1412*k12+a1413*k13)
@.. tmp = uprev+dt*(a1401*k1+a1407*k7+a1408*k8+a1409*k9+a1410*k10+a1411*k11+a1412*k12+a1413*k13)
f(k14,tmp,p,t+c14*dt)
@. tmp = uprev+dt*(a1501*k1+a1506*k6+a1507*k7+a1508*k8+a1511*k11+a1512*k12+a1513*k13+a1514*k14)
@.. tmp = uprev+dt*(a1501*k1+a1506*k6+a1507*k7+a1508*k8+a1511*k11+a1512*k12+a1513*k13+a1514*k14)
f(k15,tmp,p,t+c15*dt)
@. tmp = uprev+dt*(a1601*k1+a1606*k6+a1607*k7+a1608*k8+a1609*k9+a1613*k13+a1614*k14+a1615*k15)
@.. tmp = uprev+dt*(a1601*k1+a1606*k6+a1607*k7+a1608*k8+a1609*k9+a1613*k13+a1614*k14+a1615*k15)
f(k16,tmp,p,t+c16*dt)
@. udiff= kupdate
@. bspl = k1 - udiff
@. k[3] = udiff - k13 - bspl
@. k[4] = (d401*k1+d406*k6+d407*k7+d408*k8+d409*k9+d410*k10+d411*k11+d412*k12+d413*k13+d414*k14+d415*k15+d416*k16)
@. k[5] = (d501*k1+d506*k6+d507*k7+d508*k8+d509*k9+d510*k10+d511*k11+d512*k12+d513*k13+d514*k14+d515*k15+d516*k16)
@. k[6] = (d601*k1+d606*k6+d607*k7+d608*k8+d609*k9+d610*k10+d611*k11+d612*k12+d613*k13+d614*k14+d615*k15+d616*k16)
@. k[7] = (d701*k1+d706*k6+d707*k7+d708*k8+d709*k9+d710*k10+d711*k11+d712*k12+d713*k13+d714*k14+d715*k15+d716*k16)
copyto!(udiff, kupdate)
@.. bspl = k1 - udiff
@.. k[3] = udiff - k13 - bspl
@.. k[4] = (d401*k1+d406*k6+d407*k7+d408*k8+d409*k9+d410*k10+d411*k11+d412*k12+d413*k13+d414*k14+d415*k15+d416*k16)
@.. k[5] = (d501*k1+d506*k6+d507*k7+d508*k8+d509*k9+d510*k10+d511*k11+d512*k12+d513*k13+d514*k14+d515*k15+d516*k16)
@.. k[6] = (d601*k1+d606*k6+d607*k7+d608*k8+d609*k9+d610*k10+d611*k11+d612*k12+d613*k13+d614*k14+d615*k15+d616*k16)
@.. k[7] = (d701*k1+d706*k6+d707*k7+d708*k8+d709*k9+d710*k10+d711*k11+d712*k12+d713*k13+d714*k14+d715*k15+d716*k16)
end
end
=#

#=
@muladd function DiffEqBase.addsteps!(k,t,uprev,u,dt,f,p,cache::DP8Cache,always_calc_begin = false,allow_calc_end = true,force_calc_end = false)
if length(k)<7 || always_calc_begin
@unpack c7,c8,c9,c10,c11,c6,c5,c4,c3,c2,b1,b6,b7,b8,b9,b10,b11,b12,a0201,a0301,a0302,a0401,a0403,a0501,a0503,a0504,a0601,a0604,a0605,a0701,a0704,a0705,a0706,a0801,a0804,a0805,a0806,a0807,a0901,a0904,a0905,a0906,a0907,a0908,a1001,a1004,a1005,a1006,a1007,a1008,a1009,a1101,a1104,a1105,a1106,a1107,a1108,a1109,a1110,a1201,a1204,a1205,a1206,a1207,a1208,a1209,a1210,a1211 = cache.tab
Expand Down Expand Up @@ -167,3 +166,4 @@ end
end
end
end
=#
Loading

0 comments on commit 2d04bba

Please sign in to comment.