Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Support for containers as element_type #443

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ConfigTypeError,
ConfigValueError,
OmegaConfBaseException,
ValidationError,
)

try:
Expand Down Expand Up @@ -491,7 +492,12 @@ def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:


def valid_value_annotation_type(type_: Any) -> bool:
return type_ is Any or is_primitive_type(type_) or is_structured_config(type_)
return (
type_ is Any
or is_primitive_type(type_)
or is_structured_config(type_)
or is_container_annotation(type_)
)


def _valid_dict_key_annotation_type(type_: Any) -> bool:
Expand Down Expand Up @@ -740,3 +746,73 @@ def is_generic_dict(type_: Any) -> bool:

def is_container_annotation(type_: Any) -> bool:
    """Tell whether ``type_`` is a list annotation or a dict annotation.

    The list check runs first; the dict check is only consulted when the
    list check comes back falsy.
    """
    list_match = is_list_annotation(type_)
    return list_match or is_dict_annotation(type_)


def is_container_assignment(cfg: Any) -> bool:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function name makes zero sense outside of the context it's used.

a cfg is not a container assignment.

element_type = cfg.__dict__["_metadata"].element_type
if is_container_annotation(element_type):
return True
else:
return False
Comment on lines +753 to +756
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not

element_type = cfg.__dict__["_metadata"].element_type
return is_container_annotation(element_type)

or even

return is_container_annotation(cfg.__dict__["_metadata"].element_type)

Also, why are you accessing the metadata through dict and not through cfg._metadata?



# Returns True if value cannot be assigned to cfg's element_type.
# Returns False when the assignment is valid, or when the value must be
# unwrapped so that its contents can be checked dynamically.
# examples:
# 1. cfg.element_type = List[int], value=ListConfig(ref_type=List[str]) returns True
# 2. cfg.element_type = List[int], value=ListConfig(ref_type=List[int]) returns False
# 3. cfg.element_type = List[Any], value=ListConfig(ref_type=List[int]) returns False
# 4. cfg.element_type = List[int], value=ListConfig(ref_type=List[Any]) returns False
# needs to unwrap in order to check that every value is a valid type
# 5. cfg.element_type = List[int], value=list([1, "invalid"]) returns False
def is_invalid_container_assignment(cfg: Any, value: Any) -> bool:
    """Report whether ``value`` is a container that must be rejected by ``cfg``.

    The destination type is ``cfg``'s metadata ``element_type``; the source
    type is the ``ref_type`` recorded in ``value``'s metadata.

    Returns:
        ``True`` only when ``value`` is a ``BaseContainer`` whose ``ref_type``
        is a container annotation and ``is_legal_assignment`` rejects the
        (element_type, ref_type) pair. ``False`` in every other case,
        including any ``value`` that is not a ``BaseContainer``.

    NOTE(review): ``is_legal_list_assignment`` / ``is_legal_dict_assignment``
    compare the inner types with ``==``, so a ``List[Any]`` on either side
    would presumably be reported as invalid here -- confirm against the
    examples in the comment preceding this function.
    """
    # Imported lazily, as in the original; presumably to avoid a circular import.
    from omegaconf.basecontainer import BaseContainer

    dest_element_type = cfg.__dict__["_metadata"].element_type

    # Plain (non-config) values are never flagged by this check.
    if not isinstance(value, BaseContainer):
        return False

    src_ref_type = value._metadata.ref_type
    if not is_container_annotation(src_ref_type):
        return False
    return not is_legal_assignment(dest_element_type, src_ref_type)


def should_unwrap(cfg: Any, value: Any) -> bool:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still not clear why we are wrapping in the first place if we need to unwrap.
Can we just not wrap in this scenario? We will be able to eliminate a lot of new logic.

from omegaconf.basecontainer import BaseContainer

if is_container_assignment(cfg) and isinstance(value, BaseContainer):
return True
else:
return False


def is_legal_assignment(dest_type: Any, src_type: Any) -> bool:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_legal_assignment in _utils sounds like a very generic function checking if assignment of a source type is legal to a dest type.
It only supports container types though.
It's also implemented in a verbose way.

The following code is not tested. just making a point:

# test direct assignment.
def is_legal_assignment(dest_type: Any, src_type: Any) -> bool:
  return issubclass(src_type, dest_type) or dest_type is Any

# assignment of a value to a typed list:
valid = is_legal_assignment(lst._metadata.element_type, value_type)

# assignment of a value to a typed dict
valid = is_legal_assignment(dct._metadata.element_type, value_type)

Your implementation for this includes many functions that are written in a very verbose way.
Seems like it can be much simpler.

is_legal = False
if is_list_annotation(dest_type) and is_list_annotation(src_type):
is_legal = is_legal_list_assignment(dest_type, src_type)
elif is_dict_annotation(dest_type) and is_dict_annotation(src_type):
is_legal = is_legal_dict_assignment(dest_type, src_type)
return is_legal
Comment on lines +793 to +798
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Point on style: This is cleaner in this scenario:

    if is_list_annotation(dest_type) and is_list_annotation(src_type):
        return is_legal_list_assignment(dest_type, src_type)
    elif is_dict_annotation(dest_type) and is_dict_annotation(src_type):
        return is_legal_dict_assignment(dest_type, src_type)
    return False



def is_legal_dict_assignment(dest_type: Any, src_type: Any) -> bool:
    """Compare two dict annotations by their (key type, value type) pairs.

    Returns ``True`` when ``get_dict_key_value_types`` yields equal pairs for
    ``dest_type`` and ``src_type``.
    """
    return get_dict_key_value_types(dest_type) == get_dict_key_value_types(src_type)
Comment on lines +801 to +804
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this just dest_type == src_type?



def is_legal_list_assignment(dest_type: Any, src_type: Any) -> bool:
    """Compare two list annotations by their element types.

    Returns ``True`` when ``get_list_element_type`` yields equal element types
    for ``dest_type`` and ``src_type``. The return annotation is ``bool`` (it
    was ``Any``) to match the equality result actually returned and the
    sibling ``is_legal_dict_assignment``.
    """
    return get_list_element_type(dest_type) == get_list_element_type(src_type)


def _raise_invalid_assignment(target_type: Any, value_type: Any, value: Any) -> None:
    """Raise a ``ValidationError`` describing a type-incompatible assignment.

    The message names the offending value type and the expected target type
    (both rendered through ``type_str``) followed by the value itself.

    Raises:
        ValidationError: always.
    """
    raise ValidationError(
        f"Invalid type assigned : {type_str(value_type)} is not a "
        f"subclass of {type_str(target_type)}. value: {value}"
    )
14 changes: 12 additions & 2 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union

from ._utils import ValueKind, _get_value, format_and_raise, get_value_kind
from ._utils import (
ValueKind,
_get_value,
format_and_raise,
get_value_kind,
is_container_annotation,
)
from .errors import (
ConfigKeyError,
MissingMandatoryValue,
Expand Down Expand Up @@ -51,7 +57,11 @@ def __post_init__(self) -> None:
self.ref_type = Any
assert self.key_type is Any or isinstance(self.key_type, type)
if self.element_type is not None:
assert self.element_type is Any or isinstance(self.element_type, type)
assert (
self.element_type is Any
or isinstance(self.element_type, type)
or is_container_annotation(self.element_type)
)

if self.flags is None:
self.flags = {}
Expand Down
9 changes: 7 additions & 2 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
is_primitive_dict,
is_primitive_type,
is_structured_config,
should_unwrap,
)
from .base import Container, ContainerMetadata, DictKeyType, Node
from .errors import MissingMandatoryValue, ReadonlyConfigError, ValidationError
Expand Down Expand Up @@ -523,6 +524,8 @@ def _set_item_impl(self, key: Any, value: Any) -> None:
input_config = isinstance(value, Container)
target_node_ref = self._get_node(key)
special_value = value is None or value == "???"
expect_nested_container = is_container_annotation(self._metadata.element_type)
should_assign = not expect_nested_container and input_config

input_node = isinstance(value, ValueNode)
if isinstance(self.__dict__["_content"], dict):
Expand Down Expand Up @@ -560,6 +563,8 @@ def wrap(key: Any, val: Any) -> Node:
else:
is_optional = target._is_optional()
ref_type = target._metadata.ref_type
if input_config and should_unwrap(self, val):
val = val._value()
return _maybe_wrap(
ref_type=ref_type,
key=key,
Expand All @@ -582,7 +587,7 @@ def assign(value_key: Any, val: ValueNode) -> None:
# input is not node, can be primitive or config
if should_set_value:
self.__dict__["_content"][key]._set_value(value)
elif input_config:
elif should_assign:
assign(key, value)
else:
self.__dict__["_content"][key] = wrap(key, value)
Expand All @@ -592,7 +597,7 @@ def assign(value_key: Any, val: ValueNode) -> None:
elif not input_node and not target_node:
if should_set_value:
self.__dict__["_content"][key]._set_value(value)
elif input_config:
elif should_assign:
assign(key, value)
else:
self.__dict__["_content"][key] = wrap(key, value)
Expand Down
49 changes: 24 additions & 25 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ValueKind,
_get_value,
_is_interpolation,
_raise_invalid_assignment,
_valid_dict_key_annotation_type,
format_and_raise,
get_structured_config_data,
Expand Down Expand Up @@ -163,7 +164,7 @@ def _validate_get(self, key: Any, value: Any = None) -> None:
)

def _validate_set(self, key: Any, value: Any) -> None:
from omegaconf import OmegaConf
from omegaconf import AnyNode, OmegaConf

vk = get_value_kind(value)
if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
Expand All @@ -174,17 +175,30 @@ def _validate_set(self, key: Any, value: Any) -> None:

target = self._get_node(key) if key is not None else self

target_has_ref_type = isinstance(
target, DictConfig
) and target._metadata.ref_type not in (Any, dict)
is_valid_target = target is None or not target_has_ref_type
target_type = (
target._metadata.ref_type
if target is not None
else self._metadata.element_type
)
if target_type is Any:
return
target_has_ref_type = isinstance(target, DictConfig) and target_type not in (
Any,
dict,
)
value_type = OmegaConf.get_type(value)

input_container = (
is_structured_config(value_type) and self._metadata.element_type is not Any
)
Comment on lines +191 to +193
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not read as an input container to me.
What case are you detecting here?

Copy link
Contributor Author

@pereman2 pereman2 Dec 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, it should be input_structured_config, what do you think?
If value_type is a structured config and we have an element_type, it should do issubclass(value_type, target_type), which didn't happen because this case was considered a "valid_value".

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._metadata.element_type can also be a primitive type. (int, str etc).
I think you should reorganize the logic here a bit.

input_any_node = isinstance(value, AnyNode) and target_type is not Any
is_valid_target = (
not target_has_ref_type and not input_container and not input_any_node
)

if is_valid_target:
return

target_type = target._metadata.ref_type # type: ignore
value_type = OmegaConf.get_type(value)

if is_dict(value_type) and is_dict(target_type):
return
if is_container_annotation(target_type) and not is_container_annotation(
Expand All @@ -199,7 +213,7 @@ def _validate_set(self, key: Any, value: Any) -> None:
and not issubclass(value_type, target_type)
)
if validation_error:
self._raise_invalid_value(value, value_type, target_type)
_raise_invalid_assignment(target_type, value_type, value)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a fix to facebookresearch/hydra#1322 by any chance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, in dictconfig.py:629

else:  # pragma: no cover
    msg = f"Unsupported value type : {value}"
    raise ValidationError(msg)

should print type_str(type(value)). Should I fix it in this PR?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure it's that simple. We can do it in a different PR (we need a clean standalone repro first).


def _validate_merge(self, value: Any) -> None:
from omegaconf import OmegaConf
Expand Down Expand Up @@ -227,11 +241,7 @@ def _validate_merge(self, value: Any) -> None:
and not issubclass(src_obj_type, dest_obj_type)
)
if validation_error:
msg = (
f"Merge error : {type_str(src_obj_type)} is not a "
f"subclass of {type_str(dest_obj_type)}. value: {src}"
)
raise ValidationError(msg)
_raise_invalid_assignment(dest_obj_type, src_obj_type, value)

def _validate_non_optional(self, key: Any, value: Any) -> None:
from omegaconf import OmegaConf
Expand All @@ -253,17 +263,6 @@ def _validate_non_optional(self, key: Any, value: Any) -> None:
cause=ValidationError("field '$FULL_KEY' is not Optional"),
)

def _raise_invalid_value(
self, value: Any, value_type: Any, target_type: Any
) -> None:
assert value_type is not None
assert target_type is not None
msg = (
f"Invalid type assigned : {type_str(value_type)} is not a "
f"subclass of {type_str(target_type)}. value: {value}"
)
raise ValidationError(msg)

def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
return self._s_validate_and_normalize_key(self._metadata.key_type, key)

Expand Down
38 changes: 25 additions & 13 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
from ._utils import (
ValueKind,
_get_value,
_raise_invalid_assignment,
format_and_raise,
get_value_kind,
is_container_assignment,
is_int,
is_invalid_container_assignment,
is_primitive_list,
is_structured_config,
type_str,
should_unwrap,
valid_value_annotation_type,
)
from .base import Container, ContainerMetadata, Node
Expand Down Expand Up @@ -85,7 +88,7 @@ def _validate_get(self, key: Any, value: Any = None) -> None:
)

def _validate_set(self, key: Any, value: Any) -> None:
from omegaconf import OmegaConf
from omegaconf import OmegaConf, ValueNode

self._validate_get(key, value)

Expand All @@ -101,18 +104,23 @@ def _validate_set(self, key: Any, value: Any) -> None:
)

target_type = self._metadata.element_type
if target_type is Any:
return

value_type = OmegaConf.get_type(value)
if is_structured_config(target_type):
if (
target_type is not None
and value_type is not None
and not issubclass(value_type, target_type)
):
msg = (
f"Invalid type assigned : {type_str(value_type)} is not a "
f"subclass of {type_str(target_type)}. value: {value}"
)
raise ValidationError(msg)

needs_type_validation = is_structured_config(value_type) or isinstance(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the input is a primitive, say int? Does that mean it does not need type validation?

value, ValueNode
)
if (
needs_type_validation
and value_type is not None
and not issubclass(value_type, target_type)
) or (
is_container_assignment(self)
and is_invalid_container_assignment(self, value)
Comment on lines +119 to +121
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are only using is_invalid_container_assignment once, in ListConfig. Is it not needed in DictConfig?

):
_raise_invalid_assignment(target_type, value_type, value)

def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
res = ListConfig(None)
Expand Down Expand Up @@ -252,12 +260,16 @@ def __setitem__(self, index: Union[int, slice], value: Any) -> None:
self._format_and_raise(key=index, value=value, cause=e)

def append(self, item: Any) -> None:

try:
from omegaconf.omegaconf import OmegaConf, _maybe_wrap

index = len(self)
self._validate_set(key=index, value=item)

if should_unwrap(self, item):
item = item._value()

node = _maybe_wrap(
ref_type=self.__dict__["_metadata"].element_type,
key=index,
Expand Down
18 changes: 18 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,21 @@ class InterpolationList:
@dataclass
class InterpolationDict:
dict: Dict[str, int] = II("optimization.lr")


@dataclass
class ContainerInDict:
    """Test fixture: dict fields whose declared value types are containers."""

    # Dict whose values are annotated as List[int]; a fresh default is built
    # per instance via default_factory.
    dict_with_list: Dict[str, List[int]] = field(
        default_factory=lambda: {"foo": [1, 2]}
    )
    # Dict whose values are annotated as Dict[str, int].
    dict_with_dict: Dict[str, Dict[str, int]] = field(
        default_factory=lambda: {"foo": {"var": 1}}
    )


@dataclass
class ContainerInList:
    """Test fixture: list fields whose declared element types are containers."""

    # List whose elements are annotated as List[int].
    list_with_list: List[List[int]] = field(default_factory=lambda: [[0, 1], [2, 3]])
    # List whose elements are annotated as Dict[str, int].
    list_with_dict: List[Dict[str, int]] = field(
        default_factory=lambda: [{"foo": 1, "foo2": 2}]
    )
Comment on lines +197 to +212
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also cover cases where the type is a structured config?

Dict[str, List[User]] 
Dict[str, Dict[str, User]] 
List[List[User]]
List[Dict[str, User]]

Loading