Add LineSearchTestCase (#177)
* Add LineSearchTestCase

Also includes the failing case in PR#174.

Co-authored-by: Mateusz Baran <[email protected]>

* Add caching to all line search algorithms

* Add to docs

* Test caching for all algs

---------

Co-authored-by: Mateusz Baran <[email protected]>
timholy and mateuszbaran authored Aug 5, 2024
1 parent ded667a commit 3259cd2
Showing 15 changed files with 238 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@
docs/build
docs/src/examples/generated
/docs/Manifest.toml
+Manifest.toml
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "LineSearches"
uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
version = "7.2.0"
version = "7.3.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -13,6 +13,8 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
DoubleFloats = "1"
NLSolversBase = "7"
NaNMath = "1"
Optim = "1"
OptimTestProblems = "2"
Parameters = "0.10, 0.11, 0.12"
julia = "1.6"

8 changes: 8 additions & 0 deletions docs/src/index.md
@@ -46,6 +46,14 @@ using LineSearches
```
to load the package.

+## Debugging
+
+If you suspect a method of suboptimal performance or find that your code errors,
+create a [`LineSearchCache`](@ref) to record intermediate values for later
+inspection and analysis. If you're using this via Optim.jl, configure it inside
+the method, e.g., `Newton(linesearch=LineSearches.MoreThuente(; cache))`. The
+value stored in the cache will reflect the final iteration of line search during
+optimization.

## References

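To make the new documentation concrete, here is a minimal sketch of the Optim.jl usage it describes, assuming Optim's standard `optimize`/`Newton` API; the Rosenbrock objective and starting point are illustrative, not part of this commit:

```julia
using Optim, LineSearches

# The cache records every α, ϕ(α) (and slope, where evaluated) probed by a
# line search; per the docs above, only the final line search's trace survives.
cache = LineSearchCache{Float64}()
ls = LineSearches.MoreThuente(; cache)

rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
res = optimize(rosenbrock, [-1.2, 1.0], Newton(linesearch=ls))

# Inspect the α values, objective values, and slopes that were probed.
@show cache.alphas cache.values cache.slopes
```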
6 changes: 6 additions & 0 deletions docs/src/reference/linesearch.md
@@ -11,3 +11,9 @@ MoreThuente
Static
StrongWolfe
```

+## Debugging
+
+```@docs
+LineSearchCache
+```
27 changes: 23 additions & 4 deletions src/LineSearches.jl
@@ -1,5 +1,3 @@
-__precompile__()
-
module LineSearches

using Printf
@@ -9,13 +7,14 @@ using Parameters, NaNMath
import NLSolversBase
import NLSolversBase: AbstractObjective

-export LineSearchException
+export LineSearchException, LineSearchCache

-export BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe
+export AbstractLineSearch, BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe

export InitialHagerZhang, InitialStatic, InitialPrevious,
InitialQuadratic, InitialConstantChange


function make_ϕ(df, x_new, x, s)
function ϕ(α)
# Move a distance of alpha in the direction of s
@@ -91,6 +90,26 @@ end

include("types.jl")

+# The following don't extend `empty!` and `push!` because we want implementations for `nothing`
+# and that would be piracy
+emptycache!(cache::LineSearchCache) = begin
+empty!(cache.alphas)
+empty!(cache.values)
+empty!(cache.slopes)
+end
+emptycache!(::Nothing) = nothing
+pushcache!(cache::LineSearchCache, α, val, slope) = begin
+push!(cache.alphas, α)
+push!(cache.values, val)
+push!(cache.slopes, slope)
+end
+pushcache!(cache::LineSearchCache, α, val) = begin
+push!(cache.alphas, α)
+push!(cache.values, val)
+end
+pushcache!(::Nothing, α, val, slope) = nothing
+pushcache!(::Nothing, α, val) = nothing

# Line Search Methods
include("backtracking.jl")
include("strongwolfe.jl")
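The `emptycache!`/`pushcache!` helpers above let every algorithm instrument itself unconditionally: call sites never branch on whether a cache was configured. A small sketch of the calling pattern with made-up numbers (these helpers are internal and unexported, so the explicit import below is required):

```julia
using LineSearches
using LineSearches: emptycache!, pushcache!  # internal helpers

cache = LineSearchCache{Float64}()
pushcache!(cache, 0.5, 1.25, -3.0)    # records α, ϕ(α), and dϕ(α)
pushcache!(cache, 0.25, 0.8)          # two-argument form when no slope was computed
pushcache!(nothing, 0.5, 1.25, -3.0)  # no-op: this line search has no cache
emptycache!(nothing)                  # likewise a no-op
```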
10 changes: 8 additions & 2 deletions src/backtracking.jl
@@ -8,13 +8,14 @@ there exists a factor ρ = ρ(c₁) such that α' ≦ ρ α.
This is a modification of the algorithm described in Nocedal Wright (2nd ed), Sec. 3.5.
"""
-@with_kw struct BackTracking{TF, TI}
+@with_kw struct BackTracking{TF, TI} <: AbstractLineSearch
c_1::TF = 1e-4
ρ_hi::TF = 0.5
ρ_lo::TF = 0.1
iterations::TI = 1_000
order::TI = 3
maxstep::TF = Inf
+cache::Union{Nothing,LineSearchCache{TF}} = nothing
end
BackTracking{TF}(args...; kwargs...) where TF = BackTracking{TF,Int}(args...; kwargs...)

@@ -37,7 +38,9 @@ end

# TODO: Should we deprecate the interface that only uses the ϕ argument?
function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα
-@unpack c_1, ρ_hi, ρ_lo, iterations, order = ls
+@unpack c_1, ρ_hi, ρ_lo, iterations, order, cache = ls
+emptycache!(cache)
+pushcache!(cache, 0, ϕ_0, dϕ_0) # backtracking doesn't use the slope except here

iterfinitemax = -log2(eps(real(Tα)))

@@ -68,6 +71,8 @@ function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα

ϕx_1 = ϕ(α_2)
end
+pushcache!(cache, αinitial, ϕx_1)
+# TODO: check if value is finite (maybe iterfinite > iterfinitemax)

# Backtrack until we satisfy sufficient decrease condition
while ϕx_1 > ϕ_0 + c_1 * α_2 * dϕ_0
@@ -112,6 +117,7 @@ function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα

# Evaluate f(x) at proposed position
ϕx_0, ϕx_1 = ϕx_1, ϕ(α_2)
+pushcache!(cache, α_2, ϕx_1)
end

return α_2, ϕx_1
21 changes: 14 additions & 7 deletions src/hagerzhang.jl
@@ -80,7 +80,7 @@ Conjugate gradient line search implementation from:
conjugate gradient method with guaranteed descent. ACM
Transactions on Mathematical Software 32: 113–137.
"""
-@with_kw struct HagerZhang{T, Tm}
+@with_kw struct HagerZhang{T, Tm} <: AbstractLineSearch
delta::T = DEFAULTDELTA # c_1 Wolfe sufficient decrease condition
sigma::T = DEFAULTSIGMA # c_2 Wolfe curvature condition (Recommend 0.1 for GradientDescent)
alphamax::T = Inf
@@ -91,6 +91,7 @@ Conjugate gradient line search implementation from:
psi3::T = 0.1
display::Int = 0
mayterminate::Tm = Ref{Bool}(false)
+cache::Union{Nothing,LineSearchCache{T}} = nothing
end
HagerZhang{T}(args...; kwargs...) where T = HagerZhang{T, Base.RefValue{Bool}}(args...; kwargs...)

@@ -109,9 +110,11 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
phi_0::Real,
dphi_0::Real) where T # Should c and phi_0 be same type?
@unpack delta, sigma, alphamax, rho, epsilon, gamma,
-linesearchmax, psi3, display, mayterminate = ls
+linesearchmax, psi3, display, mayterminate, cache = ls
+emptycache!(cache)

zeroT = convert(T, 0)
+pushcache!(cache, zeroT, phi_0, dphi_0)
if !(isfinite(phi_0) && isfinite(dphi_0))
throw(LineSearchException("Value and slope at step length = 0 must be finite.", T(0)))
end
@@ -124,9 +127,13 @@
# Prevent values of x_new = x+αs that are likely to make
# ϕ(x_new) infinite
iterfinitemax::Int = ceil(Int, -log2(eps(T)))
-alphas = [zeroT] # for bisection
-values = [phi_0]
-slopes = [dphi_0]
+if cache !== nothing
+@unpack alphas, values, slopes = cache
+else
+alphas = [zeroT] # for bisection
+values = [phi_0]
+slopes = [dphi_0]
+end
if display & LINESEARCH > 0
println("New linesearch")
end
@@ -203,10 +210,10 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
else
# We'll still going downhill, expand the interval and try again.
# Reaching this branch means that dphi_c < 0 and phi_c <= phi_0 + ϵ_k
# So cold = c has a lower objective than phi_0 up to epsilon.
# This makes it a viable step to return if bracketing fails.

# Bracketing can fail if no cold < c <= alphamax can be found with finite phi_c and dphi_c.
# Going back to the loop with c = cold will only result in infinite cycling.
# So returning (cold, phi_cold) and exiting the line search is the best move.
cold = c
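Note the design choice above: `HagerZhang` already maintained `alphas`/`values`/`slopes` vectors internally for bisection, so supplying a cache simply makes those internal vectors the cache's own. A sketch of driving the functor directly with scalar closures, following the `(ϕ, ϕdϕ, c, ϕ_0, dϕ_0)` signature shown above (the quartic objective and initial step are illustrative):

```julia
using LineSearches

ϕ(α) = (α - π)^4
dϕ(α) = 4 * (α - π)^3
ϕdϕ(α) = (ϕ(α), dϕ(α))

cache = LineSearchCache{Float64}()
ls = HagerZhang(; cache)
α, ϕα = ls(ϕ, ϕdϕ, 1.0, ϕ(0.0), dϕ(0.0))

# Unlike BackTracking, HagerZhang evaluates the slope at every trial point,
# so cache.slopes fills in alongside cache.alphas and cache.values.
@show cache.alphas cache.values cache.slopes
```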
10 changes: 8 additions & 2 deletions src/morethuente.jl
@@ -138,13 +138,14 @@ The line search implementation from:
Line search algorithms with guaranteed sufficient decrease.
ACM Transactions on Mathematical Software (TOMS) 20.3 (1994): 286-307.
"""
-@with_kw struct MoreThuente{T}
+@with_kw struct MoreThuente{T} <: AbstractLineSearch
f_tol::T = 1e-4 # c_1 Wolfe sufficient decrease condition
gtol::T = 0.9 # c_2 Wolfe curvature condition (Recommend 0.1 for GradientDescent)
x_tol::T = 1e-8
alphamin::T = 1e-16
alphamax::T = 65536.0
maxfev::Int = 100
+cache::Union{Nothing,LineSearchCache{T}} = nothing
end

function (ls::MoreThuente)(df::AbstractObjective, x::AbstractArray{T},
@@ -161,13 +162,15 @@ function (ls::MoreThuente)(ϕdϕ,
alpha::T,
ϕ_0,
dϕ_0) where T
-@unpack f_tol, gtol, x_tol, alphamin, alphamax, maxfev = ls
+@unpack f_tol, gtol, x_tol, alphamin, alphamax, maxfev, cache = ls
+emptycache!(cache)

iterfinitemax = -log2(eps(T))
info = 0
info_cstep = 1 # Info from step

zeroT = convert(T, 0)
+pushcache!(cache, zeroT, ϕ_0, dϕ_0)

#
# Check the input parameters for errors.
@@ -236,7 +239,9 @@ function (ls::MoreThuente)(ϕdϕ,
# Make stmax = (3/2)*alpha < 2alpha in the first iteration below
stx = (convert(T, 7)/8)*alpha
end
+pushcache!(cache, alpha, f, dg)
# END: Ensure that the initial step provides finite function values
+# TODO: check if value is finite (maybe iterfinite > iterfinitemax)

while true
#
@@ -282,6 +287,7 @@ function (ls::MoreThuente)(ϕdϕ,
# and compute the directional derivative.
#
f, dg = ϕdϕ(alpha)
+pushcache!(cache, alpha, f, dg)
nfev += 1 # This includes calls to f() and g!()

if isapprox(dg, 0, atol=eps(T)) # Should add atol value to MoreThuente
2 changes: 1 addition & 1 deletion src/static.jl
@@ -3,7 +3,7 @@
`Static` is intended for methods with well-scaled updates; i.e. Newton, on well-behaved problems.
"""
-struct Static end
+struct Static <: AbstractLineSearch end

function (ls::Static)(df::AbstractObjective, x, s, α, x_new = similar(x), ϕ_0 = nothing, dϕ_0 = nothing)
ϕ = make_ϕ(df, x_new, x, s)
22 changes: 18 additions & 4 deletions src/strongwolfe.jl
@@ -14,10 +14,11 @@ use `MoreThuente`, `HagerZhang` or `BackTracking`.
* `c_2 = 0.9` : second (strong) Wolfe condition
* `ρ = 2.0` : bracket growth
"""
-@with_kw struct StrongWolfe{T}
+@with_kw struct StrongWolfe{T} <: AbstractLineSearch
c_1::T = 1e-4
c_2::T = 0.9
ρ::T = 2.0
+cache::Union{Nothing,LineSearchCache{T}} = nothing
end

"""
@@ -49,9 +50,11 @@ Both `alpha` and `ϕ(alpha)` are returned.
"""
function (ls::StrongWolfe)(ϕ, dϕ, ϕdϕ,
alpha0::T, ϕ_0, dϕ_0) where T<:Real
-@unpack c_1, c_2, ρ = ls
+@unpack c_1, c_2, ρ, cache = ls
+emptycache!(cache)

zeroT = convert(T, 0)
+pushcache!(cache, zeroT, ϕ_0, dϕ_0)

# Step-sizes
a_0 = zeroT
@@ -71,17 +74,21 @@

while a_i < a_max
ϕ_a_i = ϕ(a_i)
+pushcache!(cache, a_i, ϕ_a_i)

# Test Wolfe conditions
if (ϕ_a_i > ϕ_0 + c_1 * a_i * dϕ_0) ||
(ϕ_a_i >= ϕ_a_iminus1 && i > 1)
a_star = zoom(a_iminus1, a_i,
dϕ_0, ϕ_0,
-ϕ, dϕ, ϕdϕ)
+ϕ, dϕ, ϕdϕ, cache)
return a_star, ϕ(a_star)
end

dϕ_a_i = dϕ(a_i)
+if cache !== nothing
+push!(cache.slopes, dϕ_a_i)
+end

# Check condition 2
if abs(dϕ_a_i) <= -c_2 * dϕ_0
@@ -91,7 +98,7 @@
# Check condition 3
if dϕ_a_i >= zeroT # FIXME untested!
a_star = zoom(a_i, a_iminus1,
-dϕ_0, ϕ_0, ϕ, dϕ, ϕdϕ)
+dϕ_0, ϕ_0, ϕ, dϕ, ϕdϕ, cache)
return a_star, ϕ(a_star)
end

Expand All @@ -117,6 +124,7 @@ function zoom(a_lo::T,
ϕ,
dϕ,
ϕdϕ,
+cache,
c_1::Real = convert(T, 1)/10^4,
c_2::Real = convert(T, 9)/10) where T

@@ -133,8 +141,10 @@
iteration += 1

ϕ_a_lo, ϕprime_a_lo = ϕdϕ(a_lo)
+pushcache!(cache, a_lo, ϕ_a_lo, ϕprime_a_lo)

ϕ_a_hi, ϕprime_a_hi = ϕdϕ(a_hi)
+pushcache!(cache, a_hi, ϕ_a_hi, ϕprime_a_hi)

# Interpolate a_j
if a_lo < a_hi
Expand All @@ -150,6 +160,7 @@ function zoom(a_lo::T,

# Evaluate ϕ(a_j)
ϕ_a_j = ϕ(a_j)
+pushcache!(cache, a_j, ϕ_a_j)

# Check Armijo
if (ϕ_a_j > ϕ_0 + c_1 * a_j * dϕ_0) ||
@@ -158,6 +169,9 @@
else
# Evaluate ϕprime(a_j)
ϕprime_a_j = dϕ(a_j)
+if cache !== nothing
+push!(cache.slopes, ϕprime_a_j)
+end

if abs(ϕprime_a_j) <= -c_2 * dϕ_0
return a_j
36 changes: 36 additions & 0 deletions src/types.jl
@@ -2,3 +2,39 @@ mutable struct LineSearchException{T<:Real} <: Exception
message::AbstractString
alpha::T
end

+abstract type AbstractLineSearch end
+
+# For debugging
+struct LineSearchCache{T}
+alphas::Vector{T}
+values::Vector{T}
+slopes::Vector{T}
+end
+
+"""
+    cache = LineSearchCache{T}()
+
+Initialize an empty cache for storing intermediate results during line search.
+The `α`, `ϕ(α)`, and possibly `dϕ(α)` values computed during line search are
+available in `cache.alphas`, `cache.values`, and `cache.slopes`, respectively.
+
+# Example
+```jldoctest
+julia> ϕ(x) = (x - π)^4; dϕ(x) = 4*(x-π)^3;
+
+julia> cache = LineSearchCache{Float64}();
+
+julia> ls = BackTracking(; cache);
+
+julia> ls(ϕ, 10.0, ϕ(0), dϕ(0))
+(1.8481462933284658, 2.7989406670901373)
+
+julia> cache
+LineSearchCache{Float64}([0.0, 10.0, 1.8481462933284658], [97.40909103400242, 2212.550050116452, 2.7989406670901373], [-124.02510672119926])
+```
+
+Because `BackTracking` doesn't use derivatives except at `α=0`, only the initial slope was stored in the cache.
+Other methods may store all three.
+"""
+LineSearchCache{T}() where T = LineSearchCache{T}(T[], T[], T[])
3 changes: 0 additions & 3 deletions test/REQUIRE

This file was deleted.


2 comments on commit 3259cd2

@pkofod (Member) commented on 3259cd2, Aug 6, 2024


@JuliaRegistrator commented on 3259cd2

Registration pull request created: JuliaRegistries/General/112473

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v7.3.0 -m "<description of version>" 3259cd240144b96a5a3a309ea96dfb19181058b2
git push origin v7.3.0
