module OptimizationPolyalgorithms

using Reexport
@reexport using Optimization
using Optimization.SciMLBase, OptimizationOptimJL, OptimizationOptimisers
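
"""
    PolyOpt()

A polyalgorithm for `Optimization.solve`. For a deterministic, unconstrained
problem it runs `Optimisers.ADAM(0.01)` (300 iterations unless `maxiters` is
given) and then polishes the result with `BFGS(initial_stepnorm = 0.01)`.
For a deterministic problem with box constraints it uses BFGS alone, and for
a stochastic loss (or when extra data arguments are passed) it falls back to
`Optimisers.ADAM(0.1)`, in which case `maxiters` must be supplied.

A minimal usage sketch (the Rosenbrock objective, `AutoForwardDiff` backend,
and starting values below are illustrative choices, not requirements):

```julia
using Optimization, OptimizationPolyalgorithms

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
sol = solve(prob, PolyOpt())
```
"""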
struct PolyOpt end

# Both stages (ADAM and BFGS) are gradient-based, so an AD backend or
# user-supplied gradient is required.
SciMLBase.requiresgradient(opt::PolyOpt) = true

function SciMLBase.__solve(prob::OptimizationProblem,
        opt::PolyOpt,
        args...;
        maxiters = nothing,
        kwargs...)
    loss, θ = x -> prob.f(x, prob.p), prob.u0
    # Evaluate the loss twice at the same point: if the values disagree,
    # the objective is stochastic and an explicit maxiters budget is required.
    deterministic = first(loss(θ)) == first(loss(θ))
    if (!isempty(args) || !deterministic) && maxiters === nothing
        error("Automatic optimizer determination requires a deterministic loss function and no data arguments; otherwise, maxiters must be specified.")
    end
    if isempty(args) && deterministic && prob.lb === nothing && prob.ub === nothing
        # Deterministic and unconstrained: a rough ADAM pass, finished with BFGS
        if maxiters === nothing
            res1 = Optimization.solve(prob, Optimisers.ADAM(0.01), args...;
                maxiters = 300, kwargs...)
        else
            res1 = Optimization.solve(prob, Optimisers.ADAM(0.01), args...; maxiters,
                kwargs...)
        end
        # Restart from the ADAM result and polish with BFGS
        optprob2 = remake(prob, u0 = res1.u)
        res1 = Optimization.solve(optprob2, BFGS(initial_stepnorm = 0.01), args...;
            maxiters, kwargs...)
    elseif isempty(args) && deterministic
        # Deterministic but box-constrained: BFGS alone
        res1 = Optimization.solve(prob, BFGS(initial_stepnorm = 0.01), args...;
            maxiters, kwargs...)
    else
        # Stochastic loss and/or extra data arguments: ADAM with a larger step
        res1 = Optimization.solve(prob, Optimisers.ADAM(0.1), args...; maxiters,
            kwargs...)
    end
end

export PolyOpt

end
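
# For stochastic losses or data-driven fitting, the automatic two-stage path
# is unavailable and a budget must be given explicitly; a hypothetical call,
# continuing the docstring example above:
#
#     sol = solve(prob, PolyOpt(); maxiters = 1000)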