Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Draft] Migrate pydantic v2 #376

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions meerkat/interactive/app/src/lib/component/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import typing
import uuid
import warnings
from typing import Dict, List, Set
from typing import Any, ClassVar, Dict, List, Optional, Set

from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, ConfigDict, Extra, root_validator, model_validator

from meerkat.constants import MEERKAT_NPM_PACKAGE, PathHelper
from meerkat.dataframe import DataFrame
Expand Down Expand Up @@ -146,6 +146,12 @@ class BaseComponent(
PythonToSvelteMixin,
BaseModel,
):
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="allow",
)
_cache: ClassVar[Optional[Dict[str, Any]]] = {}

def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand Down Expand Up @@ -196,24 +202,24 @@ def component_name(cls):

return cls.__name__

@classproperty
@classmethod
def event_names(cls) -> List[str]:
"""Returns a list of event names that this component emits."""
return [
k[3:]
for k in cls.__fields__
for k in cls.model_fields
if k.startswith("on_")
and not issubclass(cls.__fields__[k].type_, EndpointProperty)
and not EndpointProperty.is_endpoint_property(cls.model_fields[k].annotation)
]

@classproperty
def events(cls) -> List[str]:
"""Returns a list of events that this component emits."""
return [
k
for k in cls.__fields__
for k in cls.model_fields
if k.startswith("on_")
and not issubclass(cls.__fields__[k].type_, EndpointProperty)
and not EndpointProperty.is_endpoint_property(cls.model_fields[k].annotation)
]

@classproperty
Expand Down Expand Up @@ -262,30 +268,30 @@ def path(cls):
"property of the BaseComponent correctly."
)

@classproperty
@classmethod
def prop_names(cls):
return [
k for k in cls.__fields__ if not k.startswith("on_") and "_self_id" != k
k for k in cls.model_fields if not k.startswith("on_") and "_self_id" != k
] + [
k
for k in cls.__fields__
for k in cls.model_fields
if k.startswith("on_")
and issubclass(cls.__fields__[k].type_, EndpointProperty)
and EndpointProperty.is_endpoint_property(cls.model_fields[k].annotation)
]

@classproperty
@classmethod
def prop_bindings(cls):
if not issubclass(cls, Component):
# These props need to be bound with `bind:` in Svelte
types_to_bind = {Store, DataFrame}
return {
prop: cls.__fields__[prop].type_ in types_to_bind
for prop in cls.prop_names
prop: cls.model_fields[prop].annotation in types_to_bind
for prop in cls.prop_names()
}
else:
return {
prop: (cls.__fields__[prop].type_ != EndpointProperty)
for prop in cls.prop_names
prop: not EndpointProperty.is_endpoint_property(cls.model_fields[prop].annotation)
for prop in cls.prop_names()
}

@property
Expand Down Expand Up @@ -315,23 +321,23 @@ def _frontend(value):

@property
def props(self):
return {k: self.__getattribute__(k) for k in self.prop_names}
return {k: self.__getattribute__(k) for k in self.prop_names()}

@property
def virtual_props(self):
"""Props, and all events (as_*) as props."""
vprop_names = [k for k in self.__fields__ if "_self_id" != k] + ["component_id"]
return {k: self.__getattribute__(k) for k in vprop_names}

@root_validator(pre=True)
@model_validator(mode="before")
def _init_cache(cls, values):
# This is a workaround because Pydantic automatically converts
# all Store objects to their underlying values when validating
# the class. We need to keep the Store objects around.
cls._cache = values.copy()
return values

@root_validator(pre=True)
@model_validator(mode="before")
def _endpoint_name_starts_with_on(cls, values):
"""Make sure that all `Endpoint` fields have a name that starts with
`on_`."""
Expand Down Expand Up @@ -386,7 +392,7 @@ def _get_event_interface_from_typehint(type_hint):
return out
return None

@root_validator(pre=True)
@model_validator(mode="before")
def _endpoint_signature_matches(cls, values):
"""Make sure that the signature of the Endpoint that is passed in
matches the parameter names and types that are sent from Svelte.
Expand Down Expand Up @@ -484,7 +490,7 @@ def _endpoint_signature_matches(cls, values):

return values

@root_validator(pre=False)
@model_validator(mode="before")
def _update_cache(cls, values):
# `cls._cache` only contains the values that were passed in
# `values` contains all the values, including the ones that
Expand All @@ -509,7 +515,7 @@ def _update_cache(cls, values):
pass
return values

@root_validator(pre=False)
@model_validator(mode="before")
def _check_inode(cls, values):
"""Unwrap NodeMixin objects to their underlying Node (except
Stores)."""
Expand Down Expand Up @@ -558,15 +564,21 @@ def _ipython_display_(self):
).launch()
)

class Config:
arbitrary_types_allowed = True
extra = Extra.allow
copy_on_model_validation = False
# class Config:
# arbitrary_types_allowed = True
# extra = Extra.allow
# copy_on_model_validation = False


class Component(BaseComponent):
"""Component with simple defaults."""

model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="allow",
ignored_types=(Endpoint, EndpointProperty),
)

@classproperty
def component_name(cls):
# Inheriting an existing Component and modifying it on the Python side
Expand All @@ -576,7 +588,7 @@ def component_name(cls):

return cls.__name__

@root_validator(pre=True)
@model_validator(mode="before")
def _init_cache(cls, values):
# This is a workaround because Pydantic automatically converts
# all Store objects to their underlying values when validating
Expand All @@ -595,18 +607,19 @@ def _init_cache(cls, values):

return values

@root_validator(pre=False)
@model_validator(mode="before")
def _convert_fields(cls, values: dict):
values = cls._cache
print("values", values)
cls._cache = None
for name, value in values.items():
# Wrap all the fields that are not NodeMixins in a Store
# (i.e. this will exclude DataFrame, Endpoint etc. as well as
# fields that are already Stores)
if (
name not in cls.__fields__
or cls.__fields__[name].type_ == Endpoint
or cls.__fields__[name].type_ == EndpointProperty
name not in cls.model_fields
or cls.model_fields[name].annotation == Endpoint
or cls.model_fields[name].annotation == EndpointProperty
):
# Separately skip Endpoint fields by looking at the field type,
# since they are assigned None by default and would be missed
Expand Down
25 changes: 20 additions & 5 deletions meerkat/interactive/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import typing
from functools import partial, wraps
from typing import Any, Callable, Generic, Union
from typing import Any, Callable, Generic, Optional, Union

from fastapi import APIRouter, Body
from pydantic import BaseModel, create_model
Expand Down Expand Up @@ -496,7 +496,22 @@ def validate(cls, v):


class EndpointProperty(Endpoint, Generic[T]):
pass
@classmethod
def is_endpoint_property(cls, type_hint):
"""Check if a type hint is an EndpointProperty or a Union of EndpointProperties."""
if isinstance(type_hint, type) and issubclass(type_hint, EndpointProperty):
return True

if isinstance(type_hint, typing._GenericAlias):
origin = get_type_hint_origin(type_hint)
args = get_type_hint_args(type_hint)

if origin == typing.Union:
return any(cls.is_endpoint_property(arg) for arg in args)
elif isinstance(origin, type):
return issubclass(origin, EndpointProperty)

return False


def make_endpoint(endpoint_or_fn: Union[Callable, Endpoint, None]) -> Endpoint:
Expand All @@ -509,9 +524,9 @@ def make_endpoint(endpoint_or_fn: Union[Callable, Endpoint, None]) -> Endpoint:


def endpoint(
fn: Callable = None,
prefix: Union[str, APIRouter] = None,
route: str = None,
fn: Optional[Callable] = None,
prefix: Optional[Union[str, APIRouter]] = None,
route: Optional[str] = None,
method: str = "POST",
) -> Endpoint:
"""Decorator to mark a function as an endpoint.
Expand Down
8 changes: 4 additions & 4 deletions meerkat/interactive/formatter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,12 @@ class Formatter(BaseFormatter):
# TODO: set the signature of the __init__ so it works with autocomplete and docs
def __init__(self, **kwargs):
for k in kwargs:
if k not in self.component_class.prop_names:
if k not in self.component_class.prop_names():
raise ValueError(f"{k} is not a valid prop for {self.component_class}")

for prop_name, field in self.component_class.__fields__.items():
if field.name != self.data_prop and prop_name not in kwargs:
if field.required:
for prop_name, field in self.component_class.model_fields.items():
if prop_name != self.data_prop and prop_name not in kwargs:
if field.is_required():
raise ValueError("""Missing required argument.""")
kwargs[prop_name] = field.default
self._props = kwargs
Expand Down
2 changes: 1 addition & 1 deletion meerkat/interactive/graph/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, ValidationError
from pydantic.fields import ModelField
from pydantic.v1.fields import ModelField
from wrapt import ObjectProxy

from meerkat.interactive.graph.magic import _magic, is_magic_context
Expand Down
10 changes: 6 additions & 4 deletions meerkat/interactive/svelte.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,18 +255,20 @@ def render_component_wrapper(self, component: Type[BaseComponent]):
from meerkat.interactive.startup import snake_case_to_camel_case

prop_names_camel_case = [
snake_case_to_camel_case(prop_name) for prop_name in component.prop_names
snake_case_to_camel_case(prop_name) for prop_name in component.prop_names()
]

print(component)

return template.render(
import_style=component.wrapper_import_style,
component_name=component.component_name,
path=component.path,
prop_names=component.prop_names,
prop_names=component.prop_names(),
prop_names_camel_case=prop_names_camel_case,
event_names=component.event_names,
event_names=component.event_names(),
use_bindings=True,
prop_bindings=component.prop_bindings,
prop_bindings=component.prop_bindings(),
slottable=component.slottable,
zip=zip,
is_user_app=self.app.is_user_app,
Expand Down
Loading