From 07d0f1e11ca22c65916bcda8c7eb59e01fc89776 Mon Sep 17 00:00:00 2001 From: Arjun Desai Date: Mon, 23 Dec 2024 21:59:28 -0800 Subject: [PATCH 1/2] Upgrade pydantic --- .../app/src/lib/component/abstract.py | 36 ++++++++++++------- meerkat/interactive/endpoint.py | 8 ++--- meerkat/interactive/graph/store.py | 2 +- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/meerkat/interactive/app/src/lib/component/abstract.py b/meerkat/interactive/app/src/lib/component/abstract.py index 2dd08d580..c712de03e 100644 --- a/meerkat/interactive/app/src/lib/component/abstract.py +++ b/meerkat/interactive/app/src/lib/component/abstract.py @@ -6,7 +6,7 @@ import warnings from typing import Dict, List, Set -from pydantic import BaseModel, Extra, root_validator +from pydantic import BaseModel, ConfigDict, Extra, root_validator, model_validator from meerkat.constants import MEERKAT_NPM_PACKAGE, PathHelper from meerkat.dataframe import DataFrame @@ -146,6 +146,12 @@ class BaseComponent( PythonToSvelteMixin, BaseModel, ): + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="allow", + copy_on_model_validation="none" + ) + def __init__(self, **kwargs): super().__init__(**kwargs) @@ -323,7 +329,7 @@ def virtual_props(self): vprop_names = [k for k in self.__fields__ if "_self_id" != k] + ["component_id"] return {k: self.__getattribute__(k) for k in vprop_names} - @root_validator(pre=True) + @model_validator(mode="before") def _init_cache(cls, values): # This is a workaround because Pydantic automatically converts # all Store objects to their underlying values when validating @@ -331,7 +337,7 @@ def _init_cache(cls, values): cls._cache = values.copy() return values - @root_validator(pre=True) + @model_validator(mode="before") def _endpoint_name_starts_with_on(cls, values): """Make sure that all `Endpoint` fields have a name that starts with `on_`.""" @@ -386,7 +392,7 @@ def _get_event_interface_from_typehint(type_hint): return out return None - @root_validator(pre=True) + @model_validator(mode="before") def _endpoint_signature_matches(cls, values): """Make sure that the signature of the Endpoint that is passed in matches the parameter names and types that are sent from Svelte. @@ -484,7 +490,7 @@ def _endpoint_signature_matches(cls, values): return values - @root_validator(pre=False) + @model_validator(mode="before") def _update_cache(cls, values): # `cls._cache` only contains the values that were passed in # `values` contains all the values, including the ones that @@ -509,7 +515,7 @@ def _update_cache(cls, values): pass return values - @root_validator(pre=False) + @model_validator(mode="before") def _check_inode(cls, values): """Unwrap NodeMixin objects to their underlying Node (except Stores).""" @@ -558,15 +564,21 @@ def _ipython_display_(self): ).launch() ) - class Config: - arbitrary_types_allowed = True - extra = Extra.allow - copy_on_model_validation = False + # class Config: + # arbitrary_types_allowed = True + # extra = Extra.allow + # copy_on_model_validation = False class Component(BaseComponent): """Component with simple defaults.""" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="allow", + ignored_types=(Endpoint, EndpointProperty), + ) + @classproperty def component_name(cls): # Inheriting an existing Component and modifying it on the Python side @@ -576,7 +588,7 @@ def component_name(cls): return cls.__name__ - @root_validator(pre=True) + @model_validator(mode="before") def _init_cache(cls, values): # This is a workaround because Pydantic automatically converts # all Store objects to their underlying values when validating @@ -595,7 +607,7 @@ def _init_cache(cls, values): return values - @root_validator(pre=False) + @model_validator(mode="before") def _convert_fields(cls, values: dict): values = cls._cache cls._cache = None diff --git a/meerkat/interactive/endpoint.py b/meerkat/interactive/endpoint.py index 0958f89f8..670b97622 100644 --- a/meerkat/interactive/endpoint.py +++ b/meerkat/interactive/endpoint.py @@ -4,7 +4,7 @@ import logging import typing from functools import partial, wraps -from typing import Any, Callable, Generic, Union +from typing import Any, Callable, Generic, Optional, Union from fastapi import APIRouter, Body from pydantic import BaseModel, create_model @@ -509,9 +509,9 @@ def make_endpoint(endpoint_or_fn: Union[Callable, Endpoint, None]) -> Endpoint: def endpoint( - fn: Callable = None, - prefix: Union[str, APIRouter] = None, - route: str = None, + fn: Optional[Callable] = None, + prefix: Optional[Union[str, APIRouter]] = None, + route: Optional[str] = None, method: str = "POST", ) -> Endpoint: """Decorator to mark a function as an endpoint. diff --git a/meerkat/interactive/graph/store.py b/meerkat/interactive/graph/store.py index 6a9a8fd11..ad1c4bcad 100644 --- a/meerkat/interactive/graph/store.py +++ b/meerkat/interactive/graph/store.py @@ -4,7 +4,7 @@ from fastapi.encoders import jsonable_encoder from pydantic import BaseModel, ValidationError -from pydantic.fields import ModelField +from pydantic.v1.fields import ModelField from wrapt import ObjectProxy from meerkat.interactive.graph.magic import _magic, is_magic_context From 32a3c48a91922c4f807d489188dc8a81682a3f37 Mon Sep 17 00:00:00 2001 From: Arjun Desai Date: Tue, 24 Dec 2024 07:06:24 -0800 Subject: [PATCH 2/2] debug --- .../app/src/lib/component/abstract.py | 41 ++++++++++--------- meerkat/interactive/endpoint.py | 17 +++++++- meerkat/interactive/formatter/base.py | 8 ++-- meerkat/interactive/svelte.py | 10 +++-- 4 files changed, 47 insertions(+), 29 deletions(-) diff --git a/meerkat/interactive/app/src/lib/component/abstract.py b/meerkat/interactive/app/src/lib/component/abstract.py index c712de03e..705c8ff2a 100644 --- a/meerkat/interactive/app/src/lib/component/abstract.py +++ b/meerkat/interactive/app/src/lib/component/abstract.py @@ -4,7 +4,7 @@ import typing import uuid import warnings -from typing import Dict, List, Set +from typing import Any, ClassVar, Dict, List, Optional, Set from pydantic import BaseModel, ConfigDict, Extra, root_validator, model_validator @@ -149,8 +149,8 @@ class BaseComponent( model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", - copy_on_model_validation="none" ) + _cache: ClassVar[Optional[Dict[str, Any]]] = {} def __init__(self, **kwargs): super().__init__(**kwargs) @@ -202,14 +202,14 @@ def component_name(cls): return cls.__name__ - @classproperty + @classmethod def event_names(cls) -> List[str]: """Returns a list of event names that this component emits.""" return [ k[3:] - for k in cls.__fields__ + for k in cls.model_fields if k.startswith("on_") - and not issubclass(cls.__fields__[k].type_, EndpointProperty) + and not EndpointProperty.is_endpoint_property(cls.model_fields[k].annotation) ] @classproperty @@ -217,9 +217,9 @@ def events(cls) -> List[str]: """Returns a list of events that this component emits.""" return [ k - for k in cls.__fields__ + for k in cls.model_fields if k.startswith("on_") - and not issubclass(cls.__fields__[k].type_, EndpointProperty) + and not EndpointProperty.is_endpoint_property(cls.model_fields[k].annotation) ] @classproperty @@ -268,30 +268,30 @@ def path(cls): "property of the BaseComponent correctly." ) - @classproperty + @classmethod def prop_names(cls): return [ - k for k in cls.__fields__ if not k.startswith("on_") and "_self_id" != k + k for k in cls.model_fields if not k.startswith("on_") and "_self_id" != k ] + [ k - for k in cls.__fields__ + for k in cls.model_fields if k.startswith("on_") - and issubclass(cls.__fields__[k].type_, EndpointProperty) + and EndpointProperty.is_endpoint_property(cls.model_fields[k].annotation) ] - @classproperty + @classmethod def prop_bindings(cls): if not issubclass(cls, Component): # These props need to be bound with `bind:` in Svelte types_to_bind = {Store, DataFrame} return { - prop: cls.__fields__[prop].type_ in types_to_bind - for prop in cls.prop_names + prop: cls.model_fields[prop].annotation in types_to_bind + for prop in cls.prop_names() } else: return { - prop: (cls.__fields__[prop].type_ != EndpointProperty) - for prop in cls.prop_names + prop: not EndpointProperty.is_endpoint_property(cls.model_fields[prop].annotation) + for prop in cls.prop_names() } @property @@ -321,7 +321,7 @@ def _frontend(value): @property def props(self): - return {k: self.__getattribute__(k) for k in self.prop_names} + return {k: self.__getattribute__(k) for k in self.prop_names()} @property def virtual_props(self): @@ -610,15 +610,16 @@ def _init_cache(cls, values): @model_validator(mode="before") def _convert_fields(cls, values: dict): values = cls._cache + print("values", values) cls._cache = None for name, value in values.items(): # Wrap all the fields that are not NodeMixins in a Store # (i.e. this will exclude DataFrame, Endpoint etc. as well as # fields that are already Stores) if ( - name not in cls.__fields__ - or cls.__fields__[name].type_ == Endpoint - or cls.__fields__[name].type_ == EndpointProperty + name not in cls.model_fields + or cls.model_fields[name].annotation == Endpoint + or cls.model_fields[name].annotation == EndpointProperty ): # Separately skip Endpoint fields by looking at the field type, # since they are assigned None by default and would be missed diff --git a/meerkat/interactive/endpoint.py b/meerkat/interactive/endpoint.py index 670b97622..1fffa3aaa 100644 --- a/meerkat/interactive/endpoint.py +++ b/meerkat/interactive/endpoint.py @@ -496,7 +496,22 @@ def validate(cls, v): class EndpointProperty(Endpoint, Generic[T]): - pass + @classmethod + def is_endpoint_property(cls, type_hint): + """Check if a type hint is an EndpointProperty or a Union of EndpointProperties.""" + if isinstance(type_hint, type) and issubclass(type_hint, EndpointProperty): + return True + + if isinstance(type_hint, typing._GenericAlias): + origin = get_type_hint_origin(type_hint) + args = get_type_hint_args(type_hint) + + if origin == typing.Union: + return any(cls.is_endpoint_property(arg) for arg in args) + elif isinstance(origin, type): + return issubclass(origin, EndpointProperty) + + return False def make_endpoint(endpoint_or_fn: Union[Callable, Endpoint, None]) -> Endpoint: diff --git a/meerkat/interactive/formatter/base.py b/meerkat/interactive/formatter/base.py index 3ae30f434..7655e2964 100644 --- a/meerkat/interactive/formatter/base.py +++ b/meerkat/interactive/formatter/base.py @@ -296,12 +296,12 @@ class Formatter(BaseFormatter): # TODO: set the signature of the __init__ so it works with autocomplete and docs def __init__(self, **kwargs): for k in kwargs: - if k not in self.component_class.prop_names: + if k not in self.component_class.prop_names(): raise ValueError(f"{k} is not a valid prop for {self.component_class}") - for prop_name, field in self.component_class.__fields__.items(): - if field.name != self.data_prop and prop_name not in kwargs: - if field.required: + for prop_name, field in self.component_class.model_fields.items(): + if prop_name != self.data_prop and prop_name not in kwargs: + if field.is_required(): raise ValueError("""Missing required argument.""") kwargs[prop_name] = field.default self._props = kwargs diff --git a/meerkat/interactive/svelte.py b/meerkat/interactive/svelte.py index f77bf8a45..fad6862d7 100644 --- a/meerkat/interactive/svelte.py +++ b/meerkat/interactive/svelte.py @@ -255,18 +255,20 @@ def render_component_wrapper(self, component: Type[BaseComponent]): from meerkat.interactive.startup import snake_case_to_camel_case prop_names_camel_case = [ - snake_case_to_camel_case(prop_name) for prop_name in component.prop_names + snake_case_to_camel_case(prop_name) for prop_name in component.prop_names() ] + + print(component) return template.render( import_style=component.wrapper_import_style, component_name=component.component_name, path=component.path, - prop_names=component.prop_names, + prop_names=component.prop_names(), prop_names_camel_case=prop_names_camel_case, - event_names=component.event_names, + event_names=component.event_names(), use_bindings=True, - prop_bindings=component.prop_bindings, + prop_bindings=component.prop_bindings(), slottable=component.slottable, zip=zip, is_user_app=self.app.is_user_app,