|
| 1 | +module NonlinearSolveLeastSquaresOptimExt |
| 2 | + |
| 3 | +using NonlinearSolve, SciMLBase |
| 4 | +import ConcreteStructs: @concrete |
| 5 | +import LeastSquaresOptim as LSO |
| 6 | + |
| 7 | +extension_loaded(::Val{:LeastSquaresOptim}) = true |
| 8 | + |
| 9 | +function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve} |
| 10 | + ls = linsolve == :qr ? LSO.QR() : |
| 11 | + (linsolve == :cholesky ? LSO.Cholesky() : |
| 12 | + (linsolve == :lsmr ? LSO.LSMR() : nothing)) |
| 13 | + if alg == :lm |
| 14 | + return LSO.LevenbergMarquardt(ls) |
| 15 | + elseif alg == :dogleg |
| 16 | + return LSO.Dogleg(ls) |
| 17 | + else |
| 18 | + throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg")) |
| 19 | + end |
| 20 | +end |
| 21 | + |
| 22 | +@concrete struct LeastSquaresOptimCache |
| 23 | + prob |
| 24 | + alg |
| 25 | + allocated_prob |
| 26 | + kwargs |
| 27 | +end |
| 28 | + |
| 29 | +@concrete struct FunctionWrapper{iip} |
| 30 | + f |
| 31 | + p |
| 32 | +end |
| 33 | + |
| 34 | +(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p) |
| 35 | +(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p)) |
| 36 | + |
| 37 | +function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...; |
| 38 | + abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...) |
| 39 | + iip = SciMLBase.isinplace(prob) |
| 40 | + |
| 41 | + f! = FunctionWrapper{iip}(prob.f, prob.p) |
| 42 | + g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p) |
| 43 | + |
| 44 | + lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = prob.f.resid_prototype, g!, |
| 45 | + J = prob.f.jac_prototype, alg.autodiff, |
| 46 | + output_length = length(prob.f.resid_prototype)) |
| 47 | + allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg)) |
| 48 | + |
| 49 | + return LeastSquaresOptimCache(prob, alg, allocated_prob, |
| 50 | + (; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose, |
| 51 | + kwargs...)) |
| 52 | +end |
| 53 | + |
| 54 | +function SciMLBase.solve!(cache::LeastSquaresOptimCache) |
| 55 | + res = LSO.optimize!(cache.allocated_prob; cache.kwargs...) |
| 56 | + maxiters = cache.kwargs[:iterations] |
| 57 | + retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success : |
| 58 | + (res.iterations ≥ maxiters ? ReturnCode.MaxIters : |
| 59 | + ReturnCode.ConvergenceFailure) |
| 60 | + stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations) |
| 61 | + return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2; |
| 62 | + retcode, original = res, stats) |
| 63 | +end |
| 64 | + |
| 65 | +end |
0 commit comments