Skip to content

Commit

Permalink
Ruff formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgensd committed Feb 21, 2025
1 parent 81415bf commit eb22e5b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
29 changes: 19 additions & 10 deletions python/dolfinx/nls/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import typing
from enum import Enum

from mpi4py import MPI
from petsc4py import PETSc
Expand Down Expand Up @@ -106,17 +107,21 @@ def copy_solution(self, x: PETSc.Vec): ...

def replace_solution(self, x: PETSc.Vec): ...

from enum import Enum
class SnesType(Enum):
default= 0
block=1
nest= 2

class SnesType(Enum):
default = 0
block = 1
nest = 2


def create_data_structures(a: typing.Union[list[list[fem.Form]], fem.Form], L: typing.Union[list[fem.Form], fem.Form], P: typing.Union[list[list[fem.Form]], list[fem.Form], fem.Form, None], snes_type: SnesType) -> tuple[PETSc.Mat, PETSc.Vec, PETSc.Vec, PETSc.Mat | None]:
def create_data_structures(
a: typing.Union[list[list[fem.Form]], fem.Form],
L: typing.Union[list[fem.Form], fem.Form],
P: typing.Union[list[list[fem.Form]], list[fem.Form], fem.Form, None],
snes_type: SnesType,
) -> tuple[PETSc.Mat, PETSc.Vec, PETSc.Vec, PETSc.Mat | None]:
"""Create data-structures used in PETSc NEST solvers
Args:
a: The compiled bi-linear form(s)
L: The compiled linear form(s)
Expand All @@ -143,8 +148,11 @@ def create_data_structures(a: typing.Union[list[list[fem.Form]], fem.Form], L: t
P = None if P is None else matrix_creator(P)
return A, x, b, P


class SNESSolver:
def __init__(self, problem: SNESProblemProtocol, snes_type: SnesType, options: dict | None = None):
def __init__(
self, problem: SNESProblemProtocol, snes_type: SnesType, options: dict | None = None
):
"""Initialize a PETSc-SNES solver
Args:
Expand All @@ -153,7 +161,9 @@ def __init__(self, problem: SNESProblemProtocol, snes_type: SnesType, options: d
"""
self.problem = problem
self.options = options if options is not None else {}
self._A, self._x, self._b, self._P = create_data_structures(problem.a, problem.L, problem.P, snes_type)
self._A, self._x, self._b, self._P = create_data_structures(
problem.a, problem.L, problem.P, snes_type
)
self.create_solver()
self.error_if_not_converged = True

Expand All @@ -172,7 +182,6 @@ def create_solver(self):
for key, v in self.options.items():
del opts[key]


def solve(self) -> tuple[int, int]:
"""Solve the problem and update the solution in the problem instance
Expand Down
24 changes: 18 additions & 6 deletions python/test/unit/fem/test_petsc_nonlinear_assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def blocked_solve():
J=J,
)
snes_options = {"snes_rtol": 1.0e-15, "snes_max_it": 10, "snes_monitor": None}
solver = dolfinx.nls.petsc.SNESSolver(problem, dolfinx.nls.petsc.SnesType.block, options=snes_options)
solver = dolfinx.nls.petsc.SNESSolver(
problem, dolfinx.nls.petsc.SnesType.block, options=snes_options
)
converged_reason, _ = solver.solve()
assert solver.krylov_solver.getConvergedReason() > 0
assert converged_reason > 0
Expand All @@ -359,7 +361,9 @@ def nested_solve():
}

problem = dolfinx.fem.petsc.NestSNESProblem(F, [u, p], bcs=bcs, J=J)
solver = dolfinx.nls.petsc.SNESSolver(problem, dolfinx.nls.petsc.SnesType.nest, options=snes_options)
solver = dolfinx.nls.petsc.SNESSolver(
problem, dolfinx.nls.petsc.SnesType.nest, options=snes_options
)
converged_reason, _ = solver.solve()
assert solver.krylov_solver.getConvergedReason() > 0
assert converged_reason > 0
Expand Down Expand Up @@ -403,7 +407,9 @@ def monolithic_solve():
U.sub(1).interpolate(initial_guess_p)

snes_prob = dolfinx.fem.petsc.SNESProblem(F, U, J=J, bcs=bcs)
snes_solver = dolfinx.nls.petsc.SNESSolver(snes_prob, dolfinx.nls.petsc.SnesType.default, options=snes_options)
snes_solver = dolfinx.nls.petsc.SNESSolver(
snes_prob, dolfinx.nls.petsc.SnesType.default, options=snes_options
)
snes_solver.solve()
xnorm = snes_prob.u.x.petsc_vec.norm()
return xnorm
Expand Down Expand Up @@ -501,7 +507,9 @@ def blocked():
"snes_monitor": None,
"ksp_type": "minres",
}
solver = dolfinx.nls.petsc.SNESSolver(problem, dolfinx.nls.petsc.SnesType.block, options=snes_options)
solver = dolfinx.nls.petsc.SNESSolver(
problem, dolfinx.nls.petsc.SnesType.block, options=snes_options
)
converged_reason, _ = solver.solve()
assert solver.krylov_solver.getConvergedReason() > 0
assert converged_reason > 0
Expand Down Expand Up @@ -529,7 +537,9 @@ def nested():
}

problem = dolfinx.fem.petsc.NestSNESProblem(F, [u, p], bcs=bcs, P=P)
solver = dolfinx.nls.petsc.SNESSolver(problem, dolfinx.nls.petsc.SnesType.nest, options=snes_options)
solver = dolfinx.nls.petsc.SNESSolver(
problem, dolfinx.nls.petsc.SnesType.nest, options=snes_options
)
converged_reason, _ = solver.solve()
assert solver.krylov_solver.getConvergedReason() > 0
assert converged_reason > 0
Expand Down Expand Up @@ -585,7 +595,9 @@ def monolithic():
"snes_monitor": None,
}
snes_prob = dolfinx.fem.petsc.SNESProblem(F, U, J=J, bcs=bcs, P=P)
snes_solver = dolfinx.nls.petsc.SNESSolver(snes_prob, dolfinx.nls.petsc.SnesType.default, options=snes_options)
snes_solver = dolfinx.nls.petsc.SNESSolver(
snes_prob, dolfinx.nls.petsc.SnesType.default, options=snes_options
)
snes_solver.solve()
xnorm = snes_prob.u.x.petsc_vec.norm()
Jnorm = snes_solver._A.norm()
Expand Down

0 comments on commit eb22e5b

Please # to comment.