Skip to content

Commit

Permalink
StateManagerRedis: clean up commentary
Browse files Browse the repository at this point in the history
  • Loading branch information
masenf committed Feb 10, 2024
1 parent 548a31e commit cdf4593
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,6 +1689,7 @@ async def get_state(self, token: str) -> BaseState:
Returns:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
if token not in self.states:
self.states[token] = self.state()
Expand All @@ -1713,6 +1714,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
Yields:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
if token not in self._states_locks:
async with self._state_manager_lock:
Expand Down Expand Up @@ -1774,61 +1776,69 @@ async def get_state(
Raises:
RuntimeError: when the state_cls is not specified in the token
"""
# Split the actual token from the fully qualified substate name.
client_token, _, state_path = token.partition("_")
if state_path:
# Get the State class associated with the given path.
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
else:
raise RuntimeError(
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
)

# Fetch the serialized substate from redis.
redis_state = await self.redis.get(token)

if redis_state is not None:
# Deserialize the substate.
state = cloudpickle.loads(redis_state)

# populate parent and substates
# Populate parent and substates if requested.
if parent_state is None:
# retrieve the parent state
# Retrieve the parent state from redis.
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
parent_state_key = token.rpartition(".")[0]
parent_state = await self.get_state(
parent_state_key, top_level=False, get_substates=False
)
# Set up Bidirectional linkage
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
if get_substates:
# retrieve all substates
# Retrieve all substates from redis.
for substate_cls in state_cls.get_substates():
substate_name = substate_cls.get_name()
substate_key = token + "." + substate_name
state.substates[substate_name] = await self.get_state(
substate_key, top_level=False, parent_state=state
)
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
while type(state) != self.state and state.parent_state is not None:
state = state.parent_state
return state

# Have to create a new entry for this token
# Key didn't exist so we have to create a new entry for this token.
if parent_state is None:
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
# retrieve the parent state to initialize event handlers
# Retrieve the parent state to populate event handlers onto this substate.
parent_state_key = client_token + "_" + parent_state_name
parent_state = await self.get_state(
parent_state_key, top_level=False, get_substates=False
)
# Persist the new state class to redis.
await self.set_state(
token,
state_cls(
parent_state=parent_state,
_client_token=client_token,
init_substates=False,
),
)
# After creating the state key, recursively call `get_state` to populate substates.
return await self.get_state(
token,
top_level=top_level,
Expand Down Expand Up @@ -1856,7 +1866,7 @@ async def set_state(
Raises:
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
"""
# check that we're holding the lock
# Check that we're holding the lock.
if (
lock_id is not None
and await self.redis.get(self._lock_key(token)) != lock_id
Expand All @@ -1866,9 +1876,11 @@ async def set_state(
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.background` decorator for long-running tasks."
)
# Find the substate associated with the token.
state_path = token.partition("_")[2]
if state_path and state.get_full_name() != state_path:
state = state.get_substate(tuple(state_path.split(".")))
# Persist the parent state separately, if requested.
if state.parent_state is not None and set_parent_state:
parent_state_key = token.rpartition(".")[0]
await self.set_state(
Expand All @@ -1877,12 +1889,14 @@ async def set_state(
lock_id=lock_id,
set_substates=False,
)
# Persist the substates separately, if requested.
if set_substates:
for substate_name, substate in state.substates.items():
substate_key = token + "." + substate_name
await self.set_state(
substate_key, substate, lock_id=lock_id, set_parent_state=False
)
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)

@contextlib.asynccontextmanager
Expand Down Expand Up @@ -1910,6 +1924,7 @@ def _lock_key(token: str) -> bytes:
Returns:
The redis lock key for the token.
"""
# All substates share the same lock domain, so ignore any substate path suffix.
client_token = token.partition("_")[0]
return f"{client_token}_lock".encode()

Expand Down

0 comments on commit cdf4593

Please # to comment.