-
Notifications
You must be signed in to change notification settings - Fork 121
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Support for containers as element_type #443
Changes from 18 commits
4b293ef
e798713
4ac1618
ce1c448
2ef94ca
36d455a
8770d10
3b8ede1
70e07b7
1ec53e0
e86d3e2
e16d4a6
481087b
31f42a8
a92792c
4953edc
ab7e495
0de6dbb
fa2c0a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
ConfigTypeError, | ||
ConfigValueError, | ||
OmegaConfBaseException, | ||
ValidationError, | ||
) | ||
|
||
try: | ||
|
@@ -491,7 +492,12 @@ def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]: | |
|
||
|
||
def valid_value_annotation_type(type_: Any) -> bool: | ||
return type_ is Any or is_primitive_type(type_) or is_structured_config(type_) | ||
return ( | ||
type_ is Any | ||
or is_primitive_type(type_) | ||
or is_structured_config(type_) | ||
or is_container_annotation(type_) | ||
) | ||
|
||
|
||
def _valid_dict_key_annotation_type(type_: Any) -> bool: | ||
|
@@ -740,3 +746,73 @@ def is_generic_dict(type_: Any) -> bool: | |
|
||
def is_container_annotation(type_: Any) -> bool: | ||
return is_list_annotation(type_) or is_dict_annotation(type_) | ||
|
||
|
||
def is_container_assignment(cfg: Any) -> bool: | ||
element_type = cfg.__dict__["_metadata"].element_type | ||
if is_container_annotation(element_type): | ||
return True | ||
else: | ||
return False | ||
Comment on lines
+753
to
+756
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not
or even
Also, why are you accessing the metadata through dict and not through |
||
|
||
|
||
# Returns if value cannot be assigned to cfg element_type. | ||
# We return False in case it's valid or because it should unwrap | ||
# to dynamically check its contents. | ||
# examples: | ||
# 1. cfg.element_type = List[int], value=ListConfig(ref_type=List[str]) returns True | ||
# 2. cfg.element_type = List[int], value=ListConfig(ref_type=List[int]) returns False | ||
# 3. cfg.element_type = List[Any], value=ListConfig(ref_type=List[int]) returns False | ||
# 4. cfg.element_type = List[int], value=ListConfig(ref_type=List[Any]) returns False | ||
# needs to unwrap in order to check every values is a valid type | ||
# 5. cfg.element_type = List[int], value=list([1, "invalid"]) returns False | ||
def is_invalid_container_assignment(cfg: Any, value: Any) -> bool: | ||
from omegaconf.basecontainer import BaseContainer | ||
|
||
element_type = cfg.__dict__["_metadata"].element_type | ||
|
||
if isinstance(value, BaseContainer): | ||
item_ref_type = value._metadata.ref_type | ||
if is_container_annotation(item_ref_type) and not is_legal_assignment( | ||
element_type, item_ref_type | ||
): | ||
return True | ||
return False | ||
|
||
|
||
def should_unwrap(cfg: Any, value: Any) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am still not clear why we are wrapping in the first place if we need to unwrap. |
||
from omegaconf.basecontainer import BaseContainer | ||
|
||
if is_container_assignment(cfg) and isinstance(value, BaseContainer): | ||
return True | ||
else: | ||
return False | ||
|
||
|
||
def is_legal_assignment(dest_type: Any, src_type: Any) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is_legal_assignment in _utils sounds like a very generic function checking if assignment of a source type is legal to a dest type. The following code is not tested. just making a point: # test direct assignment.
def is_legal_assignment(dest_type: Any, src_type: Any) -> bool:
return issubclass(src_type, dest_type) or dest_type is Any
# assignment of a value to a typed list:
valid = is_legal_assignment(lst._metadata.element_type, value_type)
# assignment of a value to a typed dict
valid = is_legal_assignment(dct._metadata.element_type, value_type) Your implementation for this includes many functions that are written in a very verbose way. |
||
is_legal = False | ||
if is_list_annotation(dest_type) and is_list_annotation(src_type): | ||
is_legal = is_legal_list_assignment(dest_type, src_type) | ||
elif is_dict_annotation(dest_type) and is_dict_annotation(src_type): | ||
is_legal = is_legal_dict_assignment(dest_type, src_type) | ||
return is_legal | ||
Comment on lines
+793
to
+798
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Point on style: This is cleaner in this scenario: if is_list_annotation(dest_type) and is_list_annotation(src_type):
return is_legal_list_assignment(dest_type, src_type)
elif is_dict_annotation(dest_type) and is_dict_annotation(src_type):
return is_legal_dict_assignment(dest_type, src_type)
return False |
||
|
||
|
||
def is_legal_dict_assignment(dest_type: Any, src_type: Any) -> bool: | ||
key_values_pair_dest = get_dict_key_value_types(dest_type) | ||
key_values_pair_src = get_dict_key_value_types(src_type) | ||
return key_values_pair_dest == key_values_pair_src | ||
Comment on lines
+801
to
+804
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this just |
||
|
||
|
||
def is_legal_list_assignment(dest_type: Any, src_type: Any) -> Any: | ||
element_types_dest = get_list_element_type(dest_type) | ||
element_types_src = get_list_element_type(src_type) | ||
return element_types_dest == element_types_src | ||
|
||
|
||
def _raise_invalid_assignment(target_type: Any, value_type: Any, value: Any) -> None: | ||
msg = ( | ||
f"Invalid type assigned : {type_str(value_type)} is not a " | ||
f"subclass of {type_str(target_type)}. value: {value}" | ||
) | ||
raise ValidationError(msg) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
ValueKind, | ||
_get_value, | ||
_is_interpolation, | ||
_raise_invalid_assignment, | ||
_valid_dict_key_annotation_type, | ||
format_and_raise, | ||
get_structured_config_data, | ||
|
@@ -163,7 +164,7 @@ def _validate_get(self, key: Any, value: Any = None) -> None: | |
) | ||
|
||
def _validate_set(self, key: Any, value: Any) -> None: | ||
from omegaconf import OmegaConf | ||
from omegaconf import AnyNode, OmegaConf | ||
|
||
vk = get_value_kind(value) | ||
if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION): | ||
|
@@ -174,17 +175,30 @@ def _validate_set(self, key: Any, value: Any) -> None: | |
|
||
target = self._get_node(key) if key is not None else self | ||
|
||
target_has_ref_type = isinstance( | ||
target, DictConfig | ||
) and target._metadata.ref_type not in (Any, dict) | ||
is_valid_target = target is None or not target_has_ref_type | ||
target_type = ( | ||
target._metadata.ref_type | ||
if target is not None | ||
else self._metadata.element_type | ||
) | ||
if target_type is Any: | ||
return | ||
target_has_ref_type = isinstance(target, DictConfig) and target_type not in ( | ||
Any, | ||
dict, | ||
) | ||
value_type = OmegaConf.get_type(value) | ||
|
||
input_container = ( | ||
is_structured_config(value_type) and self._metadata.element_type is not Any | ||
) | ||
Comment on lines
+191
to
+193
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this does not read as an input container to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, it should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. self._metadata.element_type can also be a primitive type. (int, str etc). |
||
input_any_node = isinstance(value, AnyNode) and target_type is not Any | ||
is_valid_target = ( | ||
not target_has_ref_type and not input_container and not input_any_node | ||
) | ||
|
||
if is_valid_target: | ||
return | ||
|
||
target_type = target._metadata.ref_type # type: ignore | ||
value_type = OmegaConf.get_type(value) | ||
|
||
if is_dict(value_type) and is_dict(target_type): | ||
return | ||
if is_container_annotation(target_type) and not is_container_annotation( | ||
|
@@ -199,7 +213,7 @@ def _validate_set(self, key: Any, value: Any) -> None: | |
and not issubclass(value_type, target_type) | ||
) | ||
if validation_error: | ||
self._raise_invalid_value(value, value_type, target_type) | ||
_raise_invalid_assignment(target_type, value_type, value) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a fix to facebookresearch/hydra#1322 by any chance? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, in dictconfig.py:629 else: # pragma: no cover
msg = f"Unsupported value type : {value}"
raise ValidationError(msg) should print type_str(type(value)). Should I fix it in this PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure it's that simple. we can do it in a different pr (need a clean standalone repro first). |
||
|
||
def _validate_merge(self, value: Any) -> None: | ||
from omegaconf import OmegaConf | ||
|
@@ -227,11 +241,7 @@ def _validate_merge(self, value: Any) -> None: | |
and not issubclass(src_obj_type, dest_obj_type) | ||
) | ||
if validation_error: | ||
msg = ( | ||
f"Merge error : {type_str(src_obj_type)} is not a " | ||
f"subclass of {type_str(dest_obj_type)}. value: {src}" | ||
) | ||
raise ValidationError(msg) | ||
_raise_invalid_assignment(dest_obj_type, src_obj_type, value) | ||
|
||
def _validate_non_optional(self, key: Any, value: Any) -> None: | ||
from omegaconf import OmegaConf | ||
|
@@ -253,17 +263,6 @@ def _validate_non_optional(self, key: Any, value: Any) -> None: | |
cause=ValidationError("field '$FULL_KEY' is not Optional"), | ||
) | ||
|
||
def _raise_invalid_value( | ||
self, value: Any, value_type: Any, target_type: Any | ||
) -> None: | ||
assert value_type is not None | ||
assert target_type is not None | ||
msg = ( | ||
f"Invalid type assigned : {type_str(value_type)} is not a " | ||
f"subclass of {type_str(target_type)}. value: {value}" | ||
) | ||
raise ValidationError(msg) | ||
|
||
def _validate_and_normalize_key(self, key: Any) -> DictKeyType: | ||
return self._s_validate_and_normalize_key(self._metadata.key_type, key) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,12 +17,15 @@ | |
from ._utils import ( | ||
ValueKind, | ||
_get_value, | ||
_raise_invalid_assignment, | ||
format_and_raise, | ||
get_value_kind, | ||
is_container_assignment, | ||
is_int, | ||
is_invalid_container_assignment, | ||
is_primitive_list, | ||
is_structured_config, | ||
type_str, | ||
should_unwrap, | ||
valid_value_annotation_type, | ||
) | ||
from .base import Container, ContainerMetadata, Node | ||
|
@@ -85,7 +88,7 @@ def _validate_get(self, key: Any, value: Any = None) -> None: | |
) | ||
|
||
def _validate_set(self, key: Any, value: Any) -> None: | ||
from omegaconf import OmegaConf | ||
from omegaconf import OmegaConf, ValueNode | ||
|
||
self._validate_get(key, value) | ||
|
||
|
@@ -101,18 +104,23 @@ def _validate_set(self, key: Any, value: Any) -> None: | |
) | ||
|
||
target_type = self._metadata.element_type | ||
if target_type is Any: | ||
return | ||
|
||
value_type = OmegaConf.get_type(value) | ||
if is_structured_config(target_type): | ||
if ( | ||
target_type is not None | ||
and value_type is not None | ||
and not issubclass(value_type, target_type) | ||
): | ||
msg = ( | ||
f"Invalid type assigned : {type_str(value_type)} is not a " | ||
f"subclass of {type_str(target_type)}. value: {value}" | ||
) | ||
raise ValidationError(msg) | ||
|
||
needs_type_validation = is_structured_config(value_type) or isinstance( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if the input is a primitive, say |
||
value, ValueNode | ||
) | ||
if ( | ||
needs_type_validation | ||
and value_type is not None | ||
and not issubclass(value_type, target_type) | ||
) or ( | ||
is_container_assignment(self) | ||
and is_invalid_container_assignment(self, value) | ||
Comment on lines
+119
to
+121
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are only using is_invalid_container_assignment once, in ListConfig. is it not needed in DictConfig? |
||
): | ||
_raise_invalid_assignment(target_type, value_type, value) | ||
|
||
def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig": | ||
res = ListConfig(None) | ||
|
@@ -252,12 +260,16 @@ def __setitem__(self, index: Union[int, slice], value: Any) -> None: | |
self._format_and_raise(key=index, value=value, cause=e) | ||
|
||
def append(self, item: Any) -> None: | ||
|
||
try: | ||
from omegaconf.omegaconf import OmegaConf, _maybe_wrap | ||
|
||
index = len(self) | ||
self._validate_set(key=index, value=item) | ||
|
||
if should_unwrap(self, item): | ||
item = item._value() | ||
|
||
node = _maybe_wrap( | ||
ref_type=self.__dict__["_metadata"].element_type, | ||
key=index, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -192,3 +192,21 @@ class InterpolationList: | |
@dataclass | ||
class InterpolationDict: | ||
dict: Dict[str, int] = II("optimization.lr") | ||
|
||
|
||
@dataclass | ||
class ContainerInDict: | ||
dict_with_list: Dict[str, List[int]] = field( | ||
default_factory=lambda: {"foo": [1, 2]} | ||
) | ||
dict_with_dict: Dict[str, Dict[str, int]] = field( | ||
default_factory=lambda: {"foo": {"var": 1}} | ||
) | ||
|
||
|
||
@dataclass | ||
class ContainerInList: | ||
list_with_list: List[List[int]] = field(default_factory=lambda: [[0, 1], [2, 3]]) | ||
list_with_dict: List[Dict[str, int]] = field( | ||
default_factory=lambda: [{"foo": 1, "foo2": 2}] | ||
) | ||
Comment on lines
+197
to
+212
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also cover cases where the type is a structured config?
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function name makes zero sense outside of the context it's used.
a cfg is not a container assignment.