Skip to content

feat: GridapPETSc wrapper #541

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -42,6 +42,8 @@ SimpleNonlinearSolve = {path = "lib/SimpleNonlinearSolve"}
[weakdeps]
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
Gridap = "56d4f2e9-7ea1-5844-9cf6-b9c51ca7ce8e"
GridapPETSc = "bcdc36c2-0c3e-11ea-095a-c9dadae499f1"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
@@ -53,9 +55,16 @@ SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

[sources]
NonlinearSolveBase = {path = "lib/NonlinearSolveBase"}
NonlinearSolveFirstOrder = {path = "lib/NonlinearSolveFirstOrder"}
NonlinearSolveQuasiNewton = {path = "lib/NonlinearSolveQuasiNewton"}
NonlinearSolveSpectralMethods = {path = "lib/NonlinearSolveSpectralMethods"}

[extensions]
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
NonlinearSolveGridapPETScExt = ["Gridap", "GridapPETSc"]
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLSolversExt = "NLSolvers"
123 changes: 123 additions & 0 deletions ext/NonlinearSolveGridapPETScExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
module NonlinearSolveGridapPETScExt

using Gridap: Gridap, Algebra
using GridapPETSc: GridapPETSc

using NonlinearSolveBase: NonlinearSolveBase
using NonlinearSolve: NonlinearSolve, GridapPETScSNES
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode

using ConcreteStructs: @concrete
using FastClosures: @closure

@concrete struct NonlinearSolveOperator <: Algebra.NonlinearOperator
f!
jac!
initial_guess_cache
resid_prototype
jacobian_prototype
end

function Algebra.residual!(b::AbstractVector, op::NonlinearSolveOperator, x::AbstractVector)
op.f!(b, x)
end

function Algebra.jacobian!(
A::AbstractMatrix, op::NonlinearSolveOperator, x::AbstractVector
)
op.jac!(A, x)
end

function Algebra.zero_initial_guess(op::NonlinearSolveOperator)
fill!(op.initial_guess_cache, 0)
return op.initial_guess_cache
end

function Algebra.allocate_residual(op::NonlinearSolveOperator, ::AbstractVector)
fill!(op.resid_prototype, 0)
return op.resid_prototype
end

function Algebra.allocate_jacobian(op::NonlinearSolveOperator, ::AbstractVector)
fill!(op.jacobian_prototype, 0)
return op.jacobian_prototype
end

# TODO: Later we should just wrap `Gridap` generally and pass in `PETSc` as the solver
function SciMLBase.__solve(
prob::NonlinearProblem, alg::GridapPETScSNES, args...;
abstol = nothing, reltol = nothing,
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val = Val(false), kwargs...
)
# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg; abs_norm_supported = false
)

f_wrapped!, u0, resid = NonlinearSolveBase.construct_extension_function_wrapper(
prob; alias_u0
)
T = eltype(u0)

abstol = NonlinearSolveBase.get_tolerance(abstol, T)
reltol = NonlinearSolveBase.get_tolerance(reltol, T)

nf = Ref{Int}(0)

f! = @closure (fx, x) -> begin
nf[] += 1
f_wrapped!(fx, x)
return fx
end

if prob.u0 isa Number
jac! = NonlinearSolveBase.construct_extension_jac(
prob, alg, prob.u0, prob.u0; alg.autodiff
)
J_init = zeros(T, 1, 1)
else
jac!, J_init = NonlinearSolveBase.construct_extension_jac(
prob, alg, u0, resid; alg.autodiff, initial_jacobian = Val(true)
)
end

njac = Ref{Int}(-1)
jac_fn! = @closure (J, x) -> begin
njac[] += 1
jac!(J, x)
return J
end

nop = NonlinearSolveOperator(f!, jac_fn!, u0, resid, J_init)

petsc_args = [
"-snes_rtol", string(reltol), "-snes_atol", string(abstol),
"-snes_max_it", string(maxiters)
]
for (k, v) in pairs(alg.snes_options)
push!(petsc_args, "-$(k)")
push!(petsc_args, string(v))
end
show_trace isa Val{true} && push!(petsc_args, "-snes_monitor")

# TODO: We can reuse the cache returned from this function
sol_u = GridapPETSc.with(args = petsc_args) do
sol_u = copy(u0)
Algebra.solve!(sol_u, GridapPETSc.PETScNonlinearSolver(), nop)
return sol_u
end

f_wrapped!(resid, sol_u)
u_res = prob.u0 isa Number ? sol_u[1] : sol_u
resid_res = prob.u0 isa Number ? resid[1] : resid

objective = maximum(abs, resid)
retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(
prob, alg, u_res, resid_res;
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
)
end

end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

12 changes: 9 additions & 3 deletions ext/NonlinearSolvePETScExt.jl
Original file line number Diff line number Diff line change
@@ -17,6 +17,11 @@ function SciMLBase.__solve(
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val = Val(false), kwargs...
)
if !MPI.Initialized()
@warn "MPI not initialized. Initializing MPI with MPI.Init()." maxlog=1
MPI.Init()
end

# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg; abs_norm_supported = false
@@ -68,8 +73,10 @@ function SciMLBase.__solve(
PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))

njac = Ref{Int}(-1)
if alg.autodiff !== missing || prob.f.jac !== nothing
# `missing` -> let PETSc compute the Jacobian using finite differences
if alg.autodiff !== missing
autodiff = alg.autodiff === missing ? nothing : alg.autodiff

if prob.u0 isa Number
jac! = NonlinearSolveBase.construct_extension_jac(
prob, alg, prob.u0, prob.u0; autodiff
@@ -125,8 +132,7 @@ function SciMLBase.__solve(
retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(
prob, alg, u_res, resid_res;
retcode, original = snes,
stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
)
end

3 changes: 3 additions & 0 deletions lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -45,3 +45,6 @@ TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[targets]
test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
@@ -92,3 +92,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BandedMatrices", "DiffEqBase", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"]

[sources]
SciMLJacobianOperators = {path = "../SciMLJacobianOperators"}
4 changes: 4 additions & 0 deletions lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
@@ -91,3 +91,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
SciMLJacobianOperators = {path = "../SciMLJacobianOperators"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveHomotopyContinuation/Project.toml
Original file line number Diff line number Diff line change
@@ -47,3 +47,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme", "NaNMath"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveQuasiNewton/Project.toml
Original file line number Diff line number Diff line change
@@ -83,3 +83,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "BenchmarkTools", "Enzyme", "ExplicitImports", "FiniteDiff", "ForwardDiff", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveSpectralMethods/Project.toml
Original file line number Diff line number Diff line change
@@ -62,3 +62,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/SCCNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -48,3 +48,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "NonlinearSolveBase", "NonlinearSolveFirstOrder", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test"]

[sources]
NonlinearSolveFirstOrder = {path = "../NonlinearSolveFirstOrder"}
3 changes: 3 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -95,3 +95,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -118,6 +118,6 @@ export NonlinearSolvePolyAlgorithm, FastShortcutNonlinearPolyalg, FastShortcutNL
# Extension Algorithms
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export PETScSNES, CMINPACK
export PETScSNES, GridapPETScSNES, CMINPACK

end
13 changes: 13 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
@@ -488,3 +488,16 @@ function PETScSNES(; petsclib = missing, autodiff = nothing, mpi_comm = missing,
end
return PETScSNES(petsclib, mpi_comm, autodiff, kwargs)
end

# TODO: Docs
@concrete struct GridapPETScSNES <: AbstractNonlinearSolveAlgorithm
autodiff
snes_options
end

function GridapPETScSNES(; autodiff = nothing, kwargs...)
if Base.get_extension(@__MODULE__, :NonlinearSolveGridapPETScExt) === nothing
error("`GridapPETScSNES` requires `GridapPETSc.jl` to be loaded")
end
return GridapPETScSNES(autodiff, kwargs)
end
Loading