Skip to content

Commit

Permalink
Add methods to MongoDBPersister and migrate
Browse files Browse the repository at this point in the history
Migrates the MongoDBPersister to b_pymongo.py for consistent naming
convention.
  • Loading branch information
jernejfrank committed Feb 1, 2025
1 parent 6c48452 commit a89c1c0
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 124 deletions.
136 changes: 13 additions & 123 deletions burr/integrations/persisters/b_mongodb.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,23 @@
import json
"""This module will be deprecated. Please use b_pymongo.py for imports."""

import logging
from datetime import datetime, timezone
from typing import Literal, Optional

from pymongo import MongoClient

from burr.core import persistence, state
from burr.integrations.persisters.b_pymongo import MongoDBBasePersister as PymongoPersister

logger = logging.getLogger(__name__)

# Module-level statement: it runs when this module is imported, so importing
# the legacy b_mongodb path logs the deprecation notice pointing to b_pymongo.py.
logger.warning(
"This class is deprecated and has been moved. "
"Please import MongoDBBasePersister from b_pymongo.py."
)

class MongoDBBasePersister(persistence.BaseStatePersister):
"""A class used to represent a MongoDB Persister.
Example usage:
.. code-block:: python

persister = MongoDBBasePersister.from_values(uri='mongodb://user:pass@localhost:27017',
db_name='mydatabase',
collection_name='mystates')
persister.save(
partition_key='example_partition',
app_id='example_app',
sequence_id=1,
position='example_position',
state=state.State({'key': 'value'}),
status='completed'
)
loaded_state = persister.load(partition_key='example_partition', app_id='example_app', sequence_id=1)
print(loaded_state)
class MongoDBBasePersister(PymongoPersister):
"""A class used to represent the MongoDB Persister.
Note: this is called MongoDBBasePersister because we had to change the constructor and wanted to make
this change backwards compatible.
This class is deprecated and has been moved to b_pymongo.py.
"""

@classmethod
Expand All @@ -45,114 +30,19 @@ def from_values(
mongo_client_kwargs: dict = None,
) -> "MongoDBBasePersister":
"""Initializes the MongoDBBasePersister class."""

if mongo_client_kwargs is None:
mongo_client_kwargs = {}
client = MongoClient(uri, **mongo_client_kwargs)
return cls(
return PymongoPersister(
client=client,
db_name=db_name,
collection_name=collection_name,
serde_kwargs=serde_kwargs,
)

def __init__(
self,
client,
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
):
"""Initializes the MongoDBBasePersister class.
:param client: the mongodb client to use
:param db_name: the name of the database to use
:param collection_name: the name of the collection to use
:param serde_kwargs: serializer/deserializer keyword arguments to pass to the state object
"""
self.client = client
self.db = self.client[db_name]
self.collection = self.db[collection_name]
self.serde_kwargs = serde_kwargs or {}

def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
"""List the app ids for a given partition key."""
app_ids = self.collection.distinct("app_id", {"partition_key": partition_key})
return app_ids

def load(
self, partition_key: str, app_id: str, sequence_id: int = None, **kwargs
) -> Optional[persistence.PersistedStateData]:
"""Load the state data for a given partition key, app id, and sequence id."""
query = {"partition_key": partition_key, "app_id": app_id}
if sequence_id is not None:
query["sequence_id"] = sequence_id
document = self.collection.find_one(query, sort=[("sequence_id", -1)])
if not document:
return None
_state = state.State.deserialize(json.loads(document["state"]), **self.serde_kwargs)
return {
"partition_key": partition_key,
"app_id": app_id,
"sequence_id": document["sequence_id"],
"position": document["position"],
"state": _state,
"created_at": document["created_at"],
"status": document["status"],
}

def save(
self,
partition_key: str,
app_id: str,
sequence_id: int,
position: str,
state: state.State,
status: Literal["completed", "failed"],
**kwargs,
):
"""Save the state data to the MongoDB database."""
key = {"partition_key": partition_key, "app_id": app_id, "sequence_id": sequence_id}
if self.collection.find_one(key):
raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.")
json_state = json.dumps(state.serialize(**self.serde_kwargs))
self.collection.insert_one(
{
"partition_key": partition_key,
"app_id": app_id,
"sequence_id": sequence_id,
"position": position,
"state": json_state,
"status": status,
"created_at": datetime.now(timezone.utc).isoformat(),
}
)

def __del__(self):
self.client.close()

def __getstate__(self) -> dict:
state = self.__dict__.copy()
state["connection_params"] = {
"uri": self.client.address[0],
"port": self.client.address[1],
"db_name": self.db.name,
"collection_name": self.collection.name,
}
del state["client"]
del state["db"]
del state["collection"]
return state

def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume MongoClient.
self.client = MongoClient(connection_params["uri"], connection_params["port"])
self.db = self.client[connection_params["db_name"]]
self.collection = self.db[connection_params["collection_name"]]
self.__dict__.update(state)


class MongoDBPersister(MongoDBBasePersister):
class MongoDBPersister(PymongoPersister):
"""A class used to represent a MongoDB Persister.
This class is deprecated. Please use MongoDBBasePersister instead.
Expand Down
176 changes: 176 additions & 0 deletions burr/integrations/persisters/b_pymongo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import json
import logging
from datetime import datetime, timezone
from typing import Literal, Optional

from pymongo import MongoClient

from burr.core import persistence, state

logger = logging.getLogger(__name__)


class MongoDBBasePersister(persistence.BaseStatePersister):
    """Persist application state in a MongoDB collection.

    Each saved state is stored as one document keyed by
    ``(partition_key, app_id, sequence_id)``; the state itself is stored as a
    JSON string.

    Example usage:

    .. code-block:: python

        persister = MongoDBBasePersister.from_values(uri='mongodb://user:pass@localhost:27017',
                                                    db_name='mydatabase',
                                                    collection_name='mystates')
        persister.save(
            partition_key='example_partition',
            app_id='example_app',
            sequence_id=1,
            position='example_position',
            state=state.State({'key': 'value'}),
            status='completed'
        )
        loaded_state = persister.load(partition_key='example_partition', app_id='example_app', sequence_id=1)
        print(loaded_state)

    The persister can also be used as a context manager, in which case the
    client is closed when the ``with`` block exits.

    Note: this is called MongoDBBasePersister because we had to change the constructor and wanted to make
    this change backwards compatible.
    """

    @classmethod
    def from_config(cls, config: dict) -> "MongoDBBasePersister":
        """Creates a new instance of the MongoDBBasePersister from a configuration dictionary.

        :param config: keyword arguments forwarded unchanged to :meth:`from_values`
        :return: the persister built by :meth:`from_values`
        """
        return cls.from_values(**config)

    @classmethod
    def from_values(
        cls,
        uri="mongodb://localhost:27017",
        db_name="mydatabase",
        collection_name="mystates",
        serde_kwargs: Optional[dict] = None,
        mongo_client_kwargs: Optional[dict] = None,
    ) -> "MongoDBBasePersister":
        """Builds a persister from connection values, creating the ``MongoClient`` itself.

        :param uri: the MongoDB connection URI
        :param db_name: the name of the database to use
        :param collection_name: the name of the collection to use
        :param serde_kwargs: serializer/deserializer keyword arguments to pass to the state object
        :param mongo_client_kwargs: extra keyword arguments passed to ``MongoClient``
        :return: a new persister instance of ``cls``
        """
        client = MongoClient(uri, **(mongo_client_kwargs or {}))
        return cls(
            client=client,
            db_name=db_name,
            collection_name=collection_name,
            serde_kwargs=serde_kwargs,
        )

    def __init__(
        self,
        client,
        db_name="mydatabase",
        collection_name="mystates",
        serde_kwargs: Optional[dict] = None,
    ):
        """Initializes the MongoDBBasePersister class.

        :param client: the mongodb client to use
        :param db_name: the name of the database to use
        :param collection_name: the name of the collection to use
        :param serde_kwargs: serializer/deserializer keyword arguments to pass to the state object
        """
        self.client = client
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        self.serde_kwargs = serde_kwargs or {}

    def __enter__(self):
        """Returns the persister itself so it can be bound with ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Closes the client on leaving the ``with`` block.

        :return: ``False``, so an exception raised inside the block propagates
        """
        # The handle is stored as `self.client`; this class never defines a
        # `connection` attribute, so closing that would raise AttributeError.
        self.client.close()
        return False

    def set_serde_kwargs(self, serde_kwargs: dict):
        """Sets the serde_kwargs for the persister.

        :param serde_kwargs: serializer/deserializer keyword arguments; a falsy
            value (e.g. ``None``) is stored as an empty dict
        """
        # Normalised the same way as in __init__, because load()/save() unpack
        # this with ** and that fails on None.
        self.serde_kwargs = serde_kwargs or {}

    def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
        """List the app ids for a given partition key.

        :param partition_key: the partition key to filter on
        :param kwargs: accepted and ignored
        :return: the distinct ``app_id`` values stored under ``partition_key``
        """
        return self.collection.distinct("app_id", {"partition_key": partition_key})

    def load(
        self, partition_key: str, app_id: str, sequence_id: Optional[int] = None, **kwargs
    ) -> Optional[persistence.PersistedStateData]:
        """Load the state data for a given partition key, app id, and sequence id.

        :param partition_key: the partition key to look up
        :param app_id: the application id to look up
        :param sequence_id: the exact sequence id to load; when ``None`` the
            document with the highest ``sequence_id`` is returned
        :param kwargs: accepted and ignored
        :return: the persisted state data, or ``None`` if no document matches
        """
        query = {"partition_key": partition_key, "app_id": app_id}
        if sequence_id is not None:
            query["sequence_id"] = sequence_id
        # Sorting descending makes find_one return the latest sequence id when
        # no explicit one was requested.
        document = self.collection.find_one(query, sort=[("sequence_id", -1)])
        if not document:
            return None
        # The state is stored as a JSON string (see save()), so decode it first.
        _state = state.State.deserialize(json.loads(document["state"]), **self.serde_kwargs)
        return {
            "partition_key": partition_key,
            "app_id": app_id,
            "sequence_id": document["sequence_id"],
            "position": document["position"],
            "state": _state,
            "created_at": document["created_at"],
            "status": document["status"],
        }

    def save(
        self,
        partition_key: str,
        app_id: str,
        sequence_id: int,
        position: str,
        state: state.State,
        status: Literal["completed", "failed"],
        **kwargs,
    ):
        """Save the state data to the MongoDB database.

        :param partition_key: the partition key of the record
        :param app_id: the application id of the record
        :param sequence_id: the sequence id of the record
        :param position: the position to store with the state
        :param state: the state to serialize and store
        :param status: either ``"completed"`` or ``"failed"``
        :param kwargs: accepted and ignored
        :raises ValueError: if a document with the same partition key, app id
            and sequence id already exists
        """
        key = {"partition_key": partition_key, "app_id": app_id, "sequence_id": sequence_id}
        # NOTE(review): the existence check and the insert are two separate
        # operations, so concurrent writers could both pass the check.
        if self.collection.find_one(key):
            raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.")
        json_state = json.dumps(state.serialize(**self.serde_kwargs))
        self.collection.insert_one(
            {
                "partition_key": partition_key,
                "app_id": app_id,
                "sequence_id": sequence_id,
                "position": position,
                "state": json_state,
                "status": status,
                "created_at": datetime.now(timezone.utc).isoformat(),
            }
        )

    def cleanup(self):
        """Closes the connection to the database."""
        # As in __exit__: the client is the only connection handle on this object.
        self.client.close()

    def __del__(self):
        # This should be deprecated -- using __del__ is unreliable for closing connections to db's;
        # the preferred way should be for the user to use a context manager or use the `.cleanup()`
        # method within a REST API framework.

        # `client` is missing when __init__ raised before assigning it; guard
        # so finalisation of a half-built object does not raise.
        client = getattr(self, "client", None)
        if client is not None:
            client.close()

    def __getstate__(self) -> dict:
        """Returns a picklable copy of the instance state.

        The client, database and collection objects are dropped and replaced
        by a ``connection_params`` dict from which ``__setstate__`` reconnects.
        """
        pickled = self.__dict__.copy()
        # NOTE(review): only the host/port reported by `client.address` is
        # kept; the "uri" key therefore holds a host name, and any other
        # client options are presumably not carried over -- confirm.
        pickled["connection_params"] = {
            "uri": self.client.address[0],
            "port": self.client.address[1],
            "db_name": self.db.name,
            "collection_name": self.collection.name,
        }
        del pickled["client"]
        del pickled["db"]
        del pickled["collection"]
        return pickled

    def __setstate__(self, state: dict):
        """Restores the instance from ``__getstate__`` output, opening a new client.

        :param state: the dict produced by ``__getstate__``
        """
        connection_params = state.pop("connection_params")
        # we assume MongoClient.
        self.client = MongoClient(connection_params["uri"], connection_params["port"])
        self.db = self.client[connection_params["db_name"]]
        self.collection = self.db[connection_params["collection_name"]]
        self.__dict__.update(state)
3 changes: 2 additions & 1 deletion tests/integrations/persisters/test_b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pytest

from burr.core import state
from burr.integrations.persisters.b_mongodb import MongoDBBasePersister, MongoDBPersister
from burr.integrations.persisters.b_mongodb import MongoDBPersister
from burr.integrations.persisters.b_pymongo import MongoDBBasePersister

# Skip every test in this module unless BURR_CI_INTEGRATION_TESTS is set to
# exactly "true".
if os.environ.get("BURR_CI_INTEGRATION_TESTS") != "true":
    pytest.skip("Skipping integration tests", allow_module_level=True)
Expand Down

0 comments on commit a89c1c0

Please sign in to comment.