Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

make object var handle all mapping instead of just dict #4602

Merged
merged 6 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,22 @@ def wrapper(*args, **kwargs):
StateIterBases = get_base_class(StateIterVar)


def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.

Args:
cls: The class to check.
cls_check: The class to check against.

Returns:
Whether the class is a subclass of the other class.
"""
try:
return issubclass(cls, cls_check)
except TypeError:
return False


def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.

Expand Down
49 changes: 29 additions & 20 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Iterable,
List,
Literal,
Mapping,
NoReturn,
Optional,
Set,
Expand Down Expand Up @@ -64,6 +65,7 @@
_isinstance,
get_origin,
has_args,
safe_issubclass,
unionize,
)

Expand Down Expand Up @@ -127,7 +129,7 @@ def __init__(
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, VarData | None] | None = None,
hooks: Mapping[str, VarData | None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
Expand Down Expand Up @@ -643,8 +645,8 @@ def to(
@overload
def to(
self,
output: type[dict],
) -> ObjectVar[dict]: ...
output: type[Mapping],
) -> ObjectVar[Mapping]: ...

@overload
def to(
Expand Down Expand Up @@ -686,7 +688,9 @@ def to(

# If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types:
if fixed_output_type in var_subclass.python_types or safe_issubclass(
fixed_output_type, var_subclass.python_types
):
return self.to(var_subclass.var_subclass, output)

if fixed_output_type is None:
Expand Down Expand Up @@ -820,7 +824,7 @@ def _get_default_value(self) -> Any:
return False
if issubclass(type_, list):
return []
if issubclass(type_, dict):
if issubclass(type_, Mapping):
return {}
if issubclass(type_, tuple):
return ()
Expand Down Expand Up @@ -1026,7 +1030,7 @@ def _as_ref(self) -> Var:
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
}
),
).to(ObjectVar, Dict[str, str])
).to(ObjectVar, Mapping[str, str])
return refs[LiteralVar.create(str(self))]

@deprecated("Use `.js_type()` instead.")
Expand Down Expand Up @@ -1373,7 +1377,7 @@ def create(

serialized_value = serializers.serialize(value)
if serialized_value is not None:
if isinstance(serialized_value, dict):
if isinstance(serialized_value, Mapping):
return LiteralObjectVar.create(
serialized_value,
_var_type=type(value),
Expand Down Expand Up @@ -1498,7 +1502,7 @@ def var_operation(
) -> Callable[P, ArrayVar[LIST_T]]: ...


OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)


@overload
Expand Down Expand Up @@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType:
return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
if isinstance(value, dict):
return Dict[
if isinstance(value, Mapping):
return Mapping[
unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())),
]
Expand Down Expand Up @@ -2002,10 +2006,10 @@ def __get__(

@overload
def __get__(
self: ComputedVar[dict[DICT_KEY, DICT_VAL]],
self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
instance: None,
owner: Type,
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...

@overload
def __get__(
Expand Down Expand Up @@ -2924,11 +2928,14 @@ def dispatch(

BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)

FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)


class Field(Generic[T]):
class Field(Generic[FIELD_TYPE]):
"""Shadow class for Var to allow for type hinting in the IDE."""

def __set__(self, instance, value: T):
def __set__(self, instance, value: FIELD_TYPE):
"""Set the Var.

Args:
Expand All @@ -2940,7 +2947,9 @@ def __set__(self, instance, value: T):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...

@overload
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
def __get__(
self: Field[int] | Field[float] | Field[int | float], instance: None, owner
) -> NumberVar: ...

@overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
Expand All @@ -2957,19 +2966,19 @@ def __get__(

@overload
def __get__(
self: Field[Dict[str, V]], instance: None, owner
) -> ObjectVar[Dict[str, V]]: ...
self: Field[MAPPING_TYPE], instance: None, owner
) -> ObjectVar[MAPPING_TYPE]: ...

@overload
def __get__(
self: Field[BASE_TYPE], instance: None, owner
) -> ObjectVar[BASE_TYPE]: ...

@overload
def __get__(self, instance: None, owner) -> Var[T]: ...
def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...

@overload
def __get__(self, instance, owner) -> T: ...
def __get__(self, instance, owner) -> FIELD_TYPE: ...

def __get__(self, instance, owner): # type: ignore
"""Get the Var.
Expand All @@ -2980,7 +2989,7 @@ def __get__(self, instance, owner): # type: ignore
"""


def field(value: T) -> Field[T]:
def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
"""Create a Field with a value.

Args:
Expand Down
Loading
Loading