diff --git a/FOX/armc/param_mapping.py b/FOX/armc/param_mapping.py index 7b47621e..2baabd61 100755 --- a/FOX/armc/param_mapping.py +++ b/FOX/armc/param_mapping.py @@ -715,5 +715,9 @@ def apply_constraints(self, idx: Tup3, value: float, param_idx: int) -> None | C if ret is not None: self.param[:] = param_backup + if exclude is not None: + ret.atoms = self.param.loc[key, :].index.difference(exclude) + else: + ret.atoms = self.param.loc[key, :].index return ret return None diff --git a/FOX/functions/charge_utils.py b/FOX/functions/charge_utils.py index 79345d30..6f95935d 100644 --- a/FOX/functions/charge_utils.py +++ b/FOX/functions/charge_utils.py @@ -16,7 +16,7 @@ from itertools import chain from collections.abc import Hashable, Collection, Container -from typing import SupportsFloat, TypeVar, Generic +from typing import SupportsFloat, TypeVar, Any import numpy as np import pandas as pd @@ -33,38 +33,55 @@ class _StateDict(TypedDict): """A dictionary representing the keyword-only arguments of :exc:`ChargeError`.""" - reference: None | float - value: None | float - tol: None | float + reference: float + value: float + tol: float + atoms: None | Collection[str] -class ChargeError(ValueError, Generic[T]): +class ChargeError(ValueError): """A :exc:`ValueError` subclass for charge-related errors.""" - __slots__ = ('__weakref__', 'reference', 'value', 'tol') + __slots__ = ('__weakref__', 'reference', 'value', 'tol', 'atoms') - reference: None | float - value: None | float - tol: None | float - args: tuple[T, ...] + reference: float + value: float + tol: float + atoms: None | Collection[str] def __init__( self, - *args: T, - reference: None | SupportsFloat = None, - value: None | SupportsFloat = None, - tol: None | SupportsFloat = 0.001, + *args: Any, + reference: SupportsFloat = 0.0, + value: SupportsFloat = 0.0, + tol: SupportsFloat = 0.001, + atoms: None | Collection[str] = None, ) -> None: """Initialize an instance.""" super().__init__(*args) - self.reference = float(reference) if reference is not None else None - self.value = float(value) if value is not None else None - self.tol = float(tol) if tol is not None else None - - def __reduce__(self: ST) -> tuple[type[ST], tuple[T, ...], _StateDict]: + self.reference = float(reference) + self.value = float(value) + self.tol = float(tol) + self.atoms = atoms + + def __str__(self) -> str: + if self.atoms is None: + return ( + f"Failed to conserve the net charge: ref = {self.reference:.4f}; " + f"{self.value:.4f} != ref" + ) + else: + atoms = "{" + ", ".join(sorted(self.atoms)) + "}" + return ( + f"Failed to conserve the net charge when moving atoms {atoms}: " + f"ref = {self.reference:.4f}; {self.value:.4f} != ref" + ) + + def __reduce__(self: ST) -> tuple[type[ST], tuple[Any, ...], _StateDict]: """Helper for :mod:`pickle`.""" cls = type(self) - kwargs = _StateDict(reference=self.reference, value=self.value, tol=self.tol) + kwargs = _StateDict(reference=self.reference, value=self.value, + tol=self.tol, atoms=self.atoms) return cls, self.args, kwargs def __setstate__(self, state: _StateDict) -> None: @@ -310,8 +327,4 @@ def _check_net_charge( if not condition: return - - raise ChargeError( - f"Failed to conserve the net charge: ref = {net_charge:.4f}; {net_charge_new:.4f} != ref", - reference=net_charge, value=net_charge_new, tol=tolerance - ) + raise ChargeError(reference=net_charge, value=net_charge_new, tol=tolerance)