Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Modify callback handling & utils #78

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions colt/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,10 @@ class UnionType: ...
from colt.lazy import Lazy
from colt.placeholder import Placeholder
from colt.registrable import Registrable
from colt.utils import remove_optional
from colt.types import ParamPath
from colt.utils import get_path_name, remove_optional

T = TypeVar("T")
ParamPath = Tuple[Union[int, str], ...]


def _get_path_name(path: ParamPath) -> str:
return ".".join(str(x) for x in path)


class ColtBuilder:
Expand Down Expand Up @@ -140,7 +136,7 @@ def _get_constructor_by_name(
constructor = cast(Type[T], DefaultRegistry.by_name(name, allow_to_import))
if constructor is None:
raise ConfigurationError(
f"[{ColtBuilder._get_path_name(path)}] type not found error: {name}"
f"[{ColtBuilder.get_path_name(path)}] type not found error: {name}"
)
return constructor

Expand Down Expand Up @@ -198,7 +194,7 @@ def _construct_args(

if not isinstance(args_config, (list, tuple)):
raise ConfigurationError(
f"[{_get_path_name(path)}] Arguments must be a list or tuple."
f"[{get_path_name(path)}] Arguments must be a list or tuple."
)
args: List[Any] = [
self._build(
Expand Down Expand Up @@ -245,7 +241,7 @@ def _build(
) -> Union[T, Any]:
if self._callback is not None:
with suppress(SkipCallback):
return self._callback.on_build(self, config, path, annotation)
config = self._callback.on_build(self, config, path, annotation)

if annotation is not None and isinstance(annotation, type):
annotation = remove_optional(annotation)
Expand All @@ -256,14 +252,14 @@ def _build(
if isinstance(config, Placeholder):
if annotation is not None and not config.match_type_hint(annotation):
raise ConfigurationError(
f"[{_get_path_name(path)}] Placeholder type mismatch: "
f"[{get_path_name(path)}] Placeholder type mismatch: "
f"expected {annotation}, got {config.type_hint}"
)
return config

if self._strict and annotation is None:
warnings.warn(
f"[{_get_path_name(path)}] Given config is not constructed because currently "
f"[{get_path_name(path)}] Given config is not constructed because currently "
"strict mode is enabled and the type annotation is not given.",
UserWarning,
)
Expand Down Expand Up @@ -323,7 +319,7 @@ def _build(

if isinstance(config, abc.Sized) and len(config) != len(args):
raise ConfigurationError(
f"[{_get_path_name(path)}] Tuple sizes of the given config and annotation "
f"[{get_path_name(path)}] Tuple sizes of the given config and annotation "
f"are mismatched: {config} / {args}"
)

Expand Down Expand Up @@ -353,7 +349,7 @@ def _build(
if origin == Literal:
if config not in args:
raise ConfigurationError(
f"[{_get_path_name(path)}] {config} is not a valid literal value."
f"[{get_path_name(path)}] {config} is not a valid literal value."
)
return config

Expand Down Expand Up @@ -394,13 +390,13 @@ def _build(
continue

trial_messages = [
f"[{_get_path_name(path)}] Trying to construct {annotation} with type {cls}:\n{e}\n{tb}"
f"[{get_path_name(path)}] Trying to construct {annotation} with type {cls}:\n{e}\n{tb}"
for cls, e, tb in trial_exceptions
]
raise ConfigurationError(
"\n\n"
+ "\n".join(textwrap.indent(msg, " ") for msg in trial_messages)
+ f"\n[{_get_path_name(path)}] Failed to construct object with type {annotation}"
+ f"\n[{get_path_name(path)}] Failed to construct object with type {annotation}"
)

if origin == Lazy:
Expand All @@ -410,12 +406,12 @@ def _build(
if isinstance(config, (list, set, tuple)):
if origin is not None and not isinstance(config, origin):
raise ConfigurationError(
f"[{_get_path_name(path)}] Type mismatch, expected type is "
f"[{get_path_name(path)}] Type mismatch, expected type is "
f"{origin}, but actual type is {type(config)}."
)
if isinstance(annotation, type) and not isinstance(config, annotation):
raise ConfigurationError(
f"[{_get_path_name(path)}] Type mismatch, expected type is "
f"[{get_path_name(path)}] Type mismatch, expected type is "
f"{annotation}, but actual type is {type(config)}."
)
cls = type(config)
Expand All @@ -433,12 +429,12 @@ def _build(
if not isinstance(config, abc.Mapping):
if origin is not None and not isinstance(config, origin):
raise ConfigurationError(
f"[{_get_path_name(path)}] Type mismatch, expected type is "
f"[{get_path_name(path)}] Type mismatch, expected type is "
f"{origin}, but actual type is {type(config)}."
)
if isinstance(annotation, type) and not isinstance(config, annotation):
raise ConfigurationError(
f"[{_get_path_name(path)}] Type mismatch, expected type is "
f"[{get_path_name(path)}] Type mismatch, expected type is "
f"{annotation}, but actual type is {type(config)}."
)
return config
Expand Down Expand Up @@ -469,7 +465,7 @@ def _build(
and not issubclass(constructor, annotation)
):
raise ConfigurationError(
f"[{_get_path_name(path)}] Type mismatch, expected type is "
f"[{get_path_name(path)}] Type mismatch, expected type is "
f"{annotation}, but actual type is {constructor}."
)

Expand All @@ -485,7 +481,7 @@ def _build(
except Exception as e:
if raise_configuration_error:
raise ConfigurationError(
f"[{_get_path_name(path)}] Failed to construct object with constructor {constructor}."
f"[{get_path_name(path)}] Failed to construct object with constructor {constructor}."
) from e
else:
raise
3 changes: 3 additions & 0 deletions colt/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Tuple, Union

ParamPath = Tuple[Union[int, str], ...]
6 changes: 6 additions & 0 deletions colt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import typing
from typing import Any, Dict, List, Optional, Sequence, Union, cast

from colt.types import ParamPath


def import_submodules(package_name: str) -> None:
"""
Expand Down Expand Up @@ -36,6 +38,10 @@ def import_modules(module_names: List[str]) -> None:
import_submodules(module_name)


def get_path_name(path: ParamPath) -> str:
return ".".join(str(x) for x in path)


def update_field(
obj: Union[Dict[Union[int, str], Any], List[Any]],
field: Union[int, str, Sequence[Union[int, str]]],
Expand Down