From 8f3fbbfa3f4a7f8e468a04e54b1f06f49a4f7233 Mon Sep 17 00:00:00 2001 From: altescy Date: Wed, 25 Sep 2024 15:41:07 +0900 Subject: [PATCH] modify callback handling & utils --- colt/builder.py | 38 +++++++++++++++++--------------------- colt/types.py | 3 +++ colt/utils.py | 6 ++++++ 3 files changed, 26 insertions(+), 21 deletions(-) create mode 100644 colt/types.py diff --git a/colt/builder.py b/colt/builder.py index 324701c..b2828e9 100644 --- a/colt/builder.py +++ b/colt/builder.py @@ -39,14 +39,10 @@ class UnionType: ... from colt.lazy import Lazy from colt.placeholder import Placeholder from colt.registrable import Registrable -from colt.utils import remove_optional +from colt.types import ParamPath +from colt.utils import get_path_name, remove_optional T = TypeVar("T") -ParamPath = Tuple[Union[int, str], ...] - - -def _get_path_name(path: ParamPath) -> str: - return ".".join(str(x) for x in path) class ColtBuilder: @@ -140,7 +136,7 @@ def _get_constructor_by_name( constructor = cast(Type[T], DefaultRegistry.by_name(name, allow_to_import)) if constructor is None: raise ConfigurationError( - f"[{ColtBuilder._get_path_name(path)}] type not found error: {name}" + f"[{ColtBuilder.get_path_name(path)}] type not found error: {name}" ) return constructor @@ -198,7 +194,7 @@ def _construct_args( if not isinstance(args_config, (list, tuple)): raise ConfigurationError( - f"[{_get_path_name(path)}] Arguments must be a list or tuple." + f"[{get_path_name(path)}] Arguments must be a list or tuple." ) args: List[Any] = [ self._build( @@ -245,7 +241,7 @@ def _build( ) -> Union[T, Any]: if self._callback is not None: with suppress(SkipCallback): - return self._callback.on_build(self, config, path, annotation) + config = self._callback.on_build(self, config, path, annotation) if annotation is not None and isinstance(annotation, type): annotation = remove_optional(annotation) @@ -256,14 +252,14 @@ def _build( if isinstance(config, Placeholder): if annotation is not None and not config.match_type_hint(annotation): raise ConfigurationError( - f"[{_get_path_name(path)}] Placeholder type mismatch: " + f"[{get_path_name(path)}] Placeholder type mismatch: " f"expected {annotation}, got {config.type_hint}" ) return config if self._strict and annotation is None: warnings.warn( - f"[{_get_path_name(path)}] Given config is not constructed because currently " + f"[{get_path_name(path)}] Given config is not constructed because currently " "strict mode is enabled and the type annotation is not given.", UserWarning, ) @@ -323,7 +319,7 @@ def _build( if isinstance(config, abc.Sized) and len(config) != len(args): raise ConfigurationError( - f"[{_get_path_name(path)}] Tuple sizes of the given config and annotation " + f"[{get_path_name(path)}] Tuple sizes of the given config and annotation " f"are mismatched: {config} / {args}" ) @@ -353,7 +349,7 @@ def _build( if origin == Literal: if config not in args: raise ConfigurationError( - f"[{_get_path_name(path)}] {config} is not a valid literal value." + f"[{get_path_name(path)}] {config} is not a valid literal value." ) return config @@ -394,13 +390,13 @@ def _build( continue trial_messages = [ - f"[{_get_path_name(path)}] Trying to construct {annotation} with type {cls}:\n{e}\n{tb}" + f"[{get_path_name(path)}] Trying to construct {annotation} with type {cls}:\n{e}\n{tb}" for cls, e, tb in trial_exceptions ] raise ConfigurationError( "\n\n" + "\n".join(textwrap.indent(msg, " ") for msg in trial_messages) - + f"\n[{_get_path_name(path)}] Failed to construct object with type {annotation}" + + f"\n[{get_path_name(path)}] Failed to construct object with type {annotation}" ) if origin == Lazy: @@ -410,12 +406,12 @@ def _build( if isinstance(config, (list, set, tuple)): if origin is not None and not isinstance(config, origin): raise ConfigurationError( - f"[{_get_path_name(path)}] Type mismatch, expected type is " + f"[{get_path_name(path)}] Type mismatch, expected type is " f"{origin}, but actual type is {type(config)}." ) if isinstance(annotation, type) and not isinstance(config, annotation): raise ConfigurationError( - f"[{_get_path_name(path)}] Type mismatch, expected type is " + f"[{get_path_name(path)}] Type mismatch, expected type is " f"{annotation}, but actual type is {type(config)}." ) cls = type(config) @@ -433,12 +429,12 @@ def _build( if not isinstance(config, abc.Mapping): if origin is not None and not isinstance(config, origin): raise ConfigurationError( - f"[{_get_path_name(path)}] Type mismatch, expected type is " + f"[{get_path_name(path)}] Type mismatch, expected type is " f"{origin}, but actual type is {type(config)}." ) if isinstance(annotation, type) and not isinstance(config, annotation): raise ConfigurationError( - f"[{_get_path_name(path)}] Type mismatch, expected type is " + f"[{get_path_name(path)}] Type mismatch, expected type is " f"{annotation}, but actual type is {type(config)}." ) return config @@ -469,7 +465,7 @@ def _build( and not issubclass(constructor, annotation) ): raise ConfigurationError( - f"[{_get_path_name(path)}] Type mismatch, expected type is " + f"[{get_path_name(path)}] Type mismatch, expected type is " f"{annotation}, but actual type is {constructor}." ) @@ -485,7 +481,7 @@ def _build( except Exception as e: if raise_configuration_error: raise ConfigurationError( - f"[{_get_path_name(path)}] Failed to construct object with constructor {constructor}." + f"[{get_path_name(path)}] Failed to construct object with constructor {constructor}." ) from e else: raise diff --git a/colt/types.py b/colt/types.py new file mode 100644 index 0000000..74aae42 --- /dev/null +++ b/colt/types.py @@ -0,0 +1,3 @@ +from typing import Tuple, Union + +ParamPath = Tuple[Union[int, str], ...] diff --git a/colt/utils.py b/colt/utils.py index 3c3095a..b047390 100644 --- a/colt/utils.py +++ b/colt/utils.py @@ -4,6 +4,8 @@ import typing from typing import Any, Dict, List, Optional, Sequence, Union, cast +from colt.types import ParamPath + def import_submodules(package_name: str) -> None: """ @@ -36,6 +38,10 @@ def import_modules(module_names: List[str]) -> None: import_submodules(module_name) +def get_path_name(path: ParamPath) -> str: + return ".".join(str(x) for x in path) + + def update_field( obj: Union[Dict[Union[int, str], Any], List[Any]], field: Union[int, str, Sequence[Union[int, str]]],