diff --git a/reflex/state.py b/reflex/state.py index 61a561e856c..acaa4dc0350 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1689,6 +1689,7 @@ async def get_state(self, token: str) -> BaseState: Returns: The state for the token. """ + # Memory state manager ignores the substate suffix and always returns the top-level state. token = token.partition("_")[0] if token not in self.states: self.states[token] = self.state() @@ -1713,6 +1714,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: Yields: The state for the token. """ + # Memory state manager ignores the substate suffix and always returns the top-level state. token = token.partition("_")[0] if token not in self._states_locks: async with self._state_manager_lock: @@ -1774,61 +1776,69 @@ async def get_state( Raises: RuntimeError: when the state_cls is not specified in the token """ + # Split the actual token from the fully qualified substate name. client_token, _, state_path = token.partition("_") if state_path: + # Get the State class associated with the given path. state_cls = self.state.get_class_substate(tuple(state_path.split("."))) else: raise RuntimeError( "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" ) + + # Fetch the serialized substate from redis. redis_state = await self.redis.get(token) if redis_state is not None: + # Deserialize the substate. state = cloudpickle.loads(redis_state) - # populate parent and substates + # Populate parent and substates if requested. if parent_state is None: - # retrieve the parent state + # Retrieve the parent state from redis. parent_state_name = state_path.rpartition(".")[0] if parent_state_name: parent_state_key = token.rpartition(".")[0] parent_state = await self.get_state( parent_state_key, top_level=False, get_substates=False ) - # Set up Bidirectional linkage + # Set up Bidirectional linkage between this state and its parent. if parent_state is not None: parent_state.substates[state.get_name()] = state state.parent_state = parent_state if get_substates: - # retrieve all substates + # Retrieve all substates from redis. for substate_cls in state_cls.get_substates(): substate_name = substate_cls.get_name() substate_key = token + "." + substate_name state.substates[substate_name] = await self.get_state( substate_key, top_level=False, parent_state=state ) + # To retain compatibility with previous implementation, by default, we return + # the top-level state by chasing `parent_state` pointers up the tree. if top_level: while type(state) != self.state and state.parent_state is not None: state = state.parent_state return state - # Have to create a new entry for this token + # Key didn't exist so we have to create a new entry for this token. if parent_state is None: parent_state_name = state_path.rpartition(".")[0] if parent_state_name: - # retrieve the parent state to initialize event handlers + # Retrieve the parent state to populate event handlers onto this substate. parent_state_key = client_token + "_" + parent_state_name parent_state = await self.get_state( parent_state_key, top_level=False, get_substates=False ) + # Persist the new state class to redis. await self.set_state( token, state_cls( parent_state=parent_state, - _client_token=client_token, init_substates=False, ), ) + # After creating the state key, recursively call `get_state` to populate substates. return await self.get_state( token, top_level=top_level, @@ -1856,7 +1866,7 @@ async def set_state( Raises: LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. """ - # check that we're holding the lock + # Check that we're holding the lock. if ( lock_id is not None and await self.redis.get(self._lock_key(token)) != lock_id @@ -1866,9 +1876,11 @@ async def set_state( f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " "or use `@rx.background` decorator for long-running tasks." ) + # Find the substate associated with the token. state_path = token.partition("_")[2] if state_path and state.get_full_name() != state_path: state = state.get_substate(tuple(state_path.split("."))) + # Persist the parent state separately, if requested. if state.parent_state is not None and set_parent_state: parent_state_key = token.rpartition(".")[0] await self.set_state( @@ -1877,12 +1889,14 @@ async def set_state( lock_id=lock_id, set_substates=False, ) + # Persist the substates separately, if requested. if set_substates: for substate_name, substate in state.substates.items(): substate_key = token + "." + substate_name await self.set_state( substate_key, substate, lock_id=lock_id, set_parent_state=False ) + # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) @contextlib.asynccontextmanager @@ -1910,6 +1924,7 @@ def _lock_key(token: str) -> bytes: Returns: The redis lock key for the token. """ + # All substates share the same lock domain, so ignore any substate path suffix. client_token = token.partition("_")[0] return f"{client_token}_lock".encode()