From 50b9d60a9788798e68f010cd5a84d2b89626e582 Mon Sep 17 00:00:00 2001 From: John Myles White Date: Mon, 12 Oct 2015 09:11:02 -0700 Subject: [PATCH] Reset search direction in L-BFGS code to ensure descent Make use of the inner product check on search directions and gradients Reset pseudo-iteration counter to reset the approximate Hessian --- src/l_bfgs.jl | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/l_bfgs.jl b/src/l_bfgs.jl index 1c64d79c4..448ab5be8 100644 --- a/src/l_bfgs.jl +++ b/src/l_bfgs.jl @@ -11,15 +11,15 @@ function twoloop!(s::Vector, dx_history::Matrix, dgr_history::Matrix, m::Integer, - iteration::Integer, + pseudo_iteration::Integer, alpha::Vector, q::Vector) # Count number of parameters n = length(s) # Determine lower and upper bounds for loops - lower = iteration - m - upper = iteration - 1 + lower = pseudo_iteration - m + upper = pseudo_iteration - 1 # Copy gr into q for backward pass copy!(q, gr) @@ -102,6 +102,7 @@ end # Count the total number of iterations iteration = 0 + pseudo_iteration = 0 # Track calls to function and gradient f_calls, g_calls = 0, 0 @@ -155,13 +156,21 @@ end while !converged && iteration < iterations # Increment the number of steps we've had to perform iteration += 1 + pseudo_iteration += 1 # Determine the L-BFGS search direction - twoloop!(s, gr, rho, dx_history, dgr_history, m, iteration, + twoloop!(s, gr, rho, dx_history, dgr_history, m, pseudo_iteration, twoloop_alpha, twoloop_q) # Refresh the line search cache dphi0 = _dot(gr, s) + if dphi0 > 0.0 + pseudo_iteration = 1 + for i in 1:n + @inbounds s[i] = -gr[i] + end + dphi0 = _dot(gr, s) + end clear!(lsr) push!(lsr, zero(T), f_x, dphi0) @@ -197,9 +206,9 @@ end # TODO: Introduce a formal error? There was a warning here previously break end - dx_history[:, mod1(iteration, m)] = dx - dgr_history[:, mod1(iteration, m)] = dgr - rho[mod1(iteration, m)] = rho_iteration + dx_history[:, mod1(pseudo_iteration, m)] = dx + dgr_history[:, mod1(pseudo_iteration, m)] = dgr + rho[mod1(pseudo_iteration, m)] = rho_iteration x_converged, f_converged,