Skip to content

Commit

Permalink
analyticless wp stochastic
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Nov 26, 2017
1 parent e56984f commit 8de789c
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 102 deletions.
5 changes: 5 additions & 0 deletions src/DiffEqDevTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import Base: length

const TIMESERIES_ERRORS = Set([:l2,:l∞,:L2,:L∞])
const DENSE_ERRORS = Set([:L2,:L∞])
const WEAK_TIMESERIES_ERRORS = Set([:weak_final])
const WEAK_DENSE_ERRORS = Set([:weak_L2])

parameterless_type(T::Type) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))

include("benchmark.jl")
include("convergence.jl")
Expand Down
224 changes: 130 additions & 94 deletions src/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function Shootout(prob,setups;appxsol=nothing,numruns=20,names=nothing,error_est
timeseries_errors = error_estimate TIMESERIES_ERRORS
dense_errors = error_estimate DENSE_ERRORS
if names == nothing
names = [string(typeof(setups[i][:alg])) for i=1:N]
names = [string(parameterless_type(setups[i][:alg])) for i=1:N]
end
for i in eachindex(setups)
sol = solve(prob,setups[i][:alg];timeseries_errors=timeseries_errors,
Expand Down Expand Up @@ -78,7 +78,7 @@ function ShootoutSet(probs,setups;probaux=nothing,numruns=20,
shootouts = Vector{Shootout}(N)
winners = Vector{String}(N)
if names == nothing
names = [string(typeof(setups[i][:alg])) for i=1:length(setups)]
names = [string(parameterless_type(setups[i][:alg])) for i=1:length(setups)]
end
if probaux == nothing
probaux = Vector{Dict{Symbol,Any}}(N)
Expand Down Expand Up @@ -147,7 +147,7 @@ function WorkPrecision(prob,alg,abstols,reltols,dts=nothing;
timeseries_errors = error_estimate TIMESERIES_ERRORS
dense_errors = error_estimate DENSE_ERRORS
for i in 1:N
t = 0.0
# Calculate errors and precompile
if dts == nothing
sol = solve(prob,alg;kwargs...,abstol=abstols[i],
reltol=reltols[i],timeseries_errors=timeseries_errors,
Expand All @@ -165,122 +165,44 @@ function WorkPrecision(prob,alg,abstols,reltols,dts=nothing;
dense_errors = dense_errors) # Compile and get result
gc()
end
t = @elapsed for j in 1:numruns
if dts == nothing
solve(prob,alg,sol.u,sol.t,sol.k;kwargs...,
abstol=abstols[i],
reltol=reltols[i],
timeseries_errors=false,
dense_errors = false)
else
solve(prob,alg,sol.u,sol.t,sol.k;
kwargs...,abstol=abstols[i],
reltol=reltols[i],dt=dts[i],
timeseries_errors=false,
dense_errors = false)
end

end
t = t/numruns

if appxsol != nothing
errsol = calculate_errsol(prob,sol,appxsol)
errsol = appxtrue(prob,sol,appxsol)
errors[i] = mean(errsol.errors[error_estimate])
else
errors[i] = mean(sol.errors[error_estimate])
end
times[i] = t
end
return WorkPrecision(prob,abstols,reltols,errors,times,name,N)
end

# This will only do strong errors
function WorkPrecision(prob::Union{AbstractRODEProblem,AbstractSDEProblem},
alg,abstols,reltols,dts=nothing;
name=nothing,numruns=20,
appxsol=nothing,error_estimate=:final,kwargs...)
N = length(abstols)
errors = Vector{Float64}(N)
times = Vector{Float64}(N)
local_errors = Vector{Float64}(numruns)
if name == nothing
name = "WP-Alg"
end
timeseries_errors = error_estimate TIMESERIES_ERRORS
dense_errors = error_estimate DENSE_ERRORS
for i in 1:N
t = 0.0
if dts == nothing
sol = solve(prob,alg;kwargs...,abstol=abstols[i],
reltol=reltols[i],timeseries_errors=timeseries_errors,
dense_errors = dense_errors) # Compile and get result
sol = solve(prob,alg,sol.u,sol.t;kwargs...,abstol=abstols[i],
reltol=reltols[i],timeseries_errors=timeseries_errors,
dense_errors = dense_errors) # Compile and get result
gc()
else
sol = solve(prob,alg;kwargs...,abstol=abstols[i],
reltol=reltols[i],dt=dts[i],timeseries_errors=timeseries_errors,
dense_errors = dense_errors) # Compile and get result
sol = solve(prob,alg,sol.u,sol.t;kwargs...,abstol=abstols[i],
reltol=reltols[i],dt=dts[i],timeseries_errors=timeseries_errors,
dense_errors = dense_errors) # Compile and get result
gc()
end
t = @elapsed for j in 1:numruns
if dts == nothing
solve(prob,alg,sol.u,sol.t;kwargs...,
t= 0.0
for j in 1:numruns
t_tmp = @elapsed if dts == nothing
solve(prob,alg,sol.u,sol.t,sol.k;kwargs...,
abstol=abstols[i],
reltol=reltols[i],
timeseries_errors=false,
dense_errors = false)
else
solve(prob,alg,sol.u,sol.t;
solve(prob,alg,sol.u,sol.t,sol.k;
kwargs...,abstol=abstols[i],
reltol=reltols[i],dt=dts[i],
timeseries_errors=false,
dense_errors = false)
end
if appxsol != nothing
errsol = calculate_errsol(prob,sol,appxsol)
local_errors[j] = errsol.errors[error_estimate]
else
local_errors[j] = sol.errors[error_estimate]
end
t += t_tmp
gc()
end
t = t/numruns

errors[i] = mean(local_errors)
times[i] = t
times[i] = t/numruns
end
return WorkPrecision(prob,abstols,reltols,errors,times,name,N)
end

function calculate_errsol(prob,sol::AbstractODESolution,appxsol_setup::Dict)
true_sol = solve(prob,appxsol_setup[:alg];appxsol_setup...)
appxtrue(sol,true_sol)
end

function calculate_errsol(prob::AbstractSDEProblem,sol::AbstractRODESolution,appxsol_setup::Dict)
prob2 = SDEProblem(prob.f,prob.g,prob.u0,prob.tspan,noise=NoiseWrapper(sol.W))
true_sol = solve(prob2,appxsol_setup[:alg];appxsol_setup...)
appxtrue(sol,true_sol)
end

function calculate_errsol(prob,sol::AbstractODESolution,true_sol::AbstractTimeseriesSolution)
appxtrue(sol,true_sol)
end

function calculate_errsol(prob::MonteCarloProblem,sol,true_sol)
appxtrue(sol,true_sol)
end

function WorkPrecisionSet(prob,abstols,reltols,setups;numruns=20,
print_names=false,names=nothing,appxsol=nothing,kwargs...)
function WorkPrecisionSet(prob::AbstractODEProblem,abstols,reltols,setups;numruns=20,
print_names=false,names=nothing,appxsol=nothing,
test_dt=nothing,kwargs...)
N = length(setups)
wps = Vector{WorkPrecision}(N)
if names == nothing
names = [string(typeof(setups[i][:alg])) for i=1:length(setups)]
names = [string(parameterless_type(setups[i][:alg])) for i=1:length(setups)]
end
for i in 1:N
print_names && println(names[i])
Expand All @@ -297,6 +219,120 @@ function WorkPrecisionSet(prob,abstols,reltols,setups;numruns=20,
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names)
end

function WorkPrecisionSet(prob,abstols,reltols,setups,test_dt=nothing;
numruns=20,numruns_error = numruns,
print_names=false,names=nothing,appxsol_setup=nothing,
error_estimate=:final,kwargs...)

timeseries_errors = has_analytic(prob.f) && error_estimate TIMESERIES_ERRORS
weak_timeseries_errors = error_estimate WEAK_TIMESERIES_ERRORS
weak_dense_errors = error_estimate WEAK_DENSE_ERRORS
dense_errors = has_analytic(prob.f) && error_estimate DENSE_ERRORS
N = length(setups); M = length(abstols)
times = Array{Float64}(M,N)
tmp_solutions = Array{DESolution}(numruns,M,N)
if names == nothing
names = [string(parameterless_type(setups[i][:alg])) for i=1:length(setups)]
end
time_tmp = Vector{Float64}(numruns)

# First calculate all of the errors
@progress for i in 1:numruns
if !has_analytic(prob.f)
t = prob.tspan[1]:test_dt:prob.tspan[2]
brownian_values = cumsum([[zeros(size(prob.u0))];[sqrt(test_dt)*randn(size(prob.u0)) for i in 1:length(t)-1]])
brownian_values2 = cumsum([[zeros(size(prob.u0))];[sqrt(test_dt)*randn(size(prob.u0)) for i in 1:length(t)-1]])
np = NoiseGrid(t,brownian_values,brownian_values2)
_prob = SDEProblem(prob.f,prob.g,prob.u0,prob.tspan,
noise=np,
noise_rate_prototype=prob.noise_rate_prototype);
true_sol = solve(_prob,appxsol_setup[:alg];kwargs...,appxsol_setup...)
end

# Get a cache
if !haskey(setups[1],:dts)
sol = solve(prob,setups[1][:alg];
kwargs...,
abstol=abstols[1],
reltol=reltols[1],
timeseries_errors=false,
dense_errors = false)
else
sol = solve(prob,setups[1][:alg];
kwargs...,abstol=abstols[1],
reltol=reltols[1],dt=setups[1][:dts][1],
timeseries_errors=false,
dense_errors = false)
end

for j in 1:M, k in 1:N
if !haskey(setups[k],:dts)
sol = solve(prob,setups[k][:alg],sol.u,sol.t;
kwargs...,
abstol=abstols[j],
reltol=reltols[j],
timeseries_errors=timeseries_errors,
dense_errors = dense_errors)
else
sol = solve(prob,setups[k][:alg],sol.u,sol.t;
kwargs...,abstol=abstols[j],
reltol=reltols[j],dt=setups[k][:dts][j],
timeseries_errors=timeseries_errors,
dense_errors = dense_errors)
end
has_analytic(prob.f) ? err_sol = sol : err_sol = appxtrue(sol,true_sol)
tmp_solutions[i,j,k] = err_sol
end
end
tmp_solutions
_solutions_k = [[MonteCarloSolution(tmp_solutions[:,j,k],0.0,true) for j in 1:M] for k in 1:N]
solutions = [[calculate_monte_errors(sim;weak_timeseries_errors=weak_timeseries_errors,weak_dense_errors=weak_dense_errors) for sim in sol_k] for sol_k in _solutions_k]
errors = [[solutions[j][i].error_means[error_estimate] for i in 1:M] for j in 1:N]

# Now time it
for k in 1:N
# Get a cache and precompile
if !haskey(setups[1],:dts)
sol = solve(prob,setups[1][:alg];
kwargs...,
abstol=abstols[1],
reltol=reltols[1],
timeseries_errors=timeseries_errors,
dense_errors = dense_errors)
else
sol = solve(prob,setups[1][:alg];
kwargs...,abstol=abstols[1],
reltol=reltols[1],dt=setups[1][:dts][j],
timeseries_errors=timeseries_errors,
dense_errors = dense_errors)
end
gc()

for j in 1:M
for i in 1:numruns
time_tmp[i] = @elapsed if !haskey(setups[k],:dts)
sol = solve(prob,setups[k][:alg],sol.u,sol.t;
kwargs...,
abstol=abstols[j],
reltol=reltols[j],
timeseries_errors=false,
dense_errors = false)
else
sol = solve(prob,setups[k][:alg],sol.u,sol.t;
kwargs...,abstol=abstols[j],
reltol=reltols[j],dt=setups[k][:dts][j],
timeseries_errors=false,
dense_errors = false)
end
end
times[j,k] = mean(time_tmp)
gc()
end
end

wps = [WorkPrecision(prob,abstols,reltols,errors[i],times[:,i],names[i],N) for i in 1:N]
WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names)
end
Base.length(wp::WorkPrecision) = wp.N
Base.size(wp::WorkPrecision) = length(wp)
Base.endof(wp::WorkPrecision) = length(wp)
Expand Down
12 changes: 4 additions & 8 deletions test/analyticless_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ prob = SDEProblem(f2,g2,[1.0;1.0],(0.0,10.0))

using StochasticDiffEq

dts = 1./2.^(8:-1:4)
test_dt = 1/2^10
sim1 = analyticless_test_convergence(dts,prob,SRIW1(),test_dt,numMonte=100)
sim2 = analyticless_test_convergence(dts,prob,RKMil(),test_dt,numMonte=100)
sim3 = analyticless_test_convergence(dts,prob,EM(),test_dt,numMonte=100)
@test sim1.𝒪est[:final]-1.5 < 0.2
@test sim2.𝒪est[:final]-1.0 < 0.2
@test sim3.𝒪est[:final]-0.5 < 0.2
dts = 1./2.^(6:-1:3)
test_dt = 1/2^8
sim1 = analyticless_test_convergence(dts,prob,SRIW1(),test_dt,numMonte=200)
@test sim1.𝒪est[:final]-1.5 < 0.3
25 changes: 25 additions & 0 deletions test/analyticless_stochastic_wp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using StochasticDiffEq, DiffEqDevTools, DiffEqProblemLibrary

prob = prob_sde_additivesystem
prob = SDEProblem(prob.f,prob.g,prob.u0,(0.0,1.0))

reltols = 1.0./10.0.^(1:5)
abstols = reltols#[0.0 for i in eachindex(reltols)]
setups = [Dict(:alg=>SRIW1())
Dict(:alg=>EM(),:dts=>1.0./5.0.^((1:length(reltols)) + 1))
Dict(:alg=>RKMil(),:dts=>1.0./5.0.^((1:length(reltols)) + 1))
Dict(:alg=>SRIW1(),:dts=>1.0./5.0.^((1:length(reltols)) + 1),:adaptive=>false)
Dict(:alg=>SRA1(),:dts=>1.0./5.0.^((1:length(reltols)) + 1),:adaptive=>false)
Dict(:alg=>SRA1())
]
names = ["SRIW1","EM","RKMil","SRIW1 Fixed","SRA1 Fixed","SRA1"]
test_dt = 0.1
wp = WorkPrecisionSet(prob,abstols,reltols,setups,test_dt;
numruns=5,names=names,error_estimate=:l2)

prob2 = SDEProblem((t,u,du)->prob.f(t,u,du),prob.g,prob.u0,(0.0,1.0))
test_dt = 2/10^5
appxsol_setup = Dict(:alg=>SRIW1(),:abstol=>1e-7,:reltol=>1e-7)
wp = WorkPrecisionSet(prob,abstols,reltols,setups,test_dt;
appxsol_setup = appxsol_setup,
numruns=5,names=names,error_estimate=:l2)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ using Base.Test
@time @testset "ODE AppxTrue Tests" begin include("ode_appxtrue_tests.jl") end
@time @testset "Analyticless Convergence Tests" begin include("analyticless_convergence_tests.jl") end
@time @testset "ODE Tableau Convergence Tests" begin include("ode_tableau_convergence_tests.jl") end ## Windows 32-bit fails on Butcher62 convergence test
@time @testset "Analyticless Stochastic WP" begin include("analyticless_stochastic_wp.jl") end

0 comments on commit 8de789c

Please sign in to comment.