Skip to content

Commit e68df3e

Browse files
authored
Merge pull request #1 from nickfrev/save-with-cascade-include-nested
Multiple fixes for cascade=True save issues
2 parents 2cab8a0 + e6ef197 commit e68df3e

File tree

6 files changed

+206
-72
lines changed

6 files changed

+206
-72
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,4 +263,5 @@ that much better:
263263
* Timothé Perez (https://github.com/AchilleAsh)
264264
* oleksandr-l5 (https://github.com/oleksandr-l5)
265265
* Ido Shraga (https://github.com/idoshr)
266+
* Nick Freville (https://github.com/nickfrev)
266267
* Terence Honles (https://github.com/terencehonles)

mongoengine/base/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"ComplexBaseField",
2828
"ObjectIdField",
2929
"GeoJsonBaseField",
30+
"SaveableBaseField",
3031
# metaclasses
3132
"DocumentMetaclass",
3233
"TopLevelDocumentMetaclass",

mongoengine/base/fields.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from mongoengine.common import _import_class
1414
from mongoengine.errors import DeprecatedError, ValidationError
1515

16-
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
16+
__all__ = ("BaseField", "SaveableBaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
1717

1818

1919
class BaseField:
@@ -259,7 +259,14 @@ def owner_document(self, owner_document):
259259
self._set_owner_document(owner_document)
260260

261261

262-
class ComplexBaseField(BaseField):
262+
class SaveableBaseField(BaseField):
263+
"""A base class that dictates a field has the ability to save.
264+
"""
265+
def save():
266+
pass
267+
268+
269+
class ComplexBaseField(SaveableBaseField):
263270
"""Handles complex fields, such as lists / dictionaries.
264271
265272
Allows for nesting of embedded documents inside complex types.
@@ -483,6 +490,16 @@ def validate(self, value):
483490
if self.required and not value:
484491
self.error("Field is required and cannot be empty")
485492

493+
def save(self, instance, **kwargs):
494+
Document = _import_class("Document")
495+
value = instance._data.get(self.name)
496+
497+
for ref in value:
498+
if isinstance(ref, SaveableBaseField):
499+
ref.save(self, **kwargs)
500+
elif isinstance(ref, Document):
501+
ref.save(**kwargs)
502+
486503
def prepare_query_value(self, op, value):
487504
return self.to_mongo(value)
488505

mongoengine/document.py

Lines changed: 86 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import re
22

33
import pymongo
4+
from bson import SON
45
from bson.dbref import DBRef
56
from pymongo.read_preferences import ReadPreference
67

78
from mongoengine import signals
89
from mongoengine.base import (
910
BaseDict,
1011
BaseDocument,
12+
SaveableBaseField,
1113
BaseList,
1214
DocumentMetaclass,
1315
EmbeddedDocumentList,
@@ -385,44 +387,34 @@ def save(
385387
the cascade save using cascade_kwargs which overwrites the
386388
existing kwargs with custom values.
387389
"""
388-
signal_kwargs = signal_kwargs or {}
389-
390-
if self._meta.get("abstract"):
391-
raise InvalidDocumentError("Cannot save an abstract document.")
392-
393-
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
394-
395-
if validate:
396-
self.validate(clean=clean)
397-
398-
if write_concern is None:
399-
write_concern = {}
390+
# Used to avoid saving a document that is already saving (infinite loops)
391+
# this can be caused by the cascade save and circular references
392+
if getattr(self, "_is_saving", False):
393+
return
394+
self._is_saving = True
400395

401-
doc_id = self.to_mongo(fields=[self._meta["id_field"]])
402-
created = "_id" not in doc_id or self._created or force_insert
396+
try:
397+
signal_kwargs = signal_kwargs or {}
403398

404-
signals.pre_save_post_validation.send(
405-
self.__class__, document=self, created=created, **signal_kwargs
406-
)
407-
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
408-
doc = self.to_mongo()
399+
if write_concern is None:
400+
write_concern = {}
409401

410-
if self._meta.get("auto_create_index", True):
411-
self.ensure_indexes()
412-
413-
try:
414-
# Save a new document or update an existing one
415-
if created:
416-
object_id = self._save_create(doc, force_insert, write_concern)
417-
else:
418-
object_id, created = self._save_update(
419-
doc, save_condition, write_concern
420-
)
402+
if self._meta.get("abstract"):
403+
raise InvalidDocumentError("Cannot save an abstract document.")
421404

405+
# Cascade save before validation to avoid child not existing errors
422406
if cascade is None:
423407
cascade = self._meta.get("cascade", False) or cascade_kwargs is not None
424408

409+
has_placeholder_saved = False
410+
425411
if cascade:
412+
# If a cascade will occur save a placeholder version of this document to
413+
# avoid issues with cyclic saves if this doc has not been created yet
414+
if self.id is None:
415+
self._save_place_holder(force_insert, write_concern)
416+
has_placeholder_saved = True
417+
426418
kwargs = {
427419
"force_insert": force_insert,
428420
"validate": validate,
@@ -434,31 +426,74 @@ def save(
434426
kwargs["_refs"] = _refs
435427
self.cascade_save(**kwargs)
436428

437-
except pymongo.errors.DuplicateKeyError as err:
438-
message = "Tried to save duplicate unique keys (%s)"
439-
raise NotUniqueError(message % err)
440-
except pymongo.errors.OperationFailure as err:
441-
message = "Could not save document (%s)"
442-
if re.match("^E1100[01] duplicate key", str(err)):
443-
# E11000 - duplicate key error index
444-
# E11001 - duplicate key on update
429+
# update force_insert to reflect that we might have already run the insert for
430+
# the placeholder
431+
force_insert = force_insert and not has_placeholder_saved
432+
433+
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
434+
435+
if validate:
436+
self.validate(clean=clean)
437+
438+
doc_id = self.to_mongo(fields=[self._meta["id_field"]])
439+
created = "_id" not in doc_id or self._created or force_insert
440+
441+
signals.pre_save_post_validation.send(
442+
self.__class__, document=self, created=created, **signal_kwargs
443+
)
444+
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
445+
doc = self.to_mongo()
446+
447+
if self._meta.get("auto_create_index", True):
448+
self.ensure_indexes()
449+
450+
try:
451+
# Save a new document or update an existing one
452+
if created:
453+
object_id = self._save_create(doc, force_insert, write_concern)
454+
else:
455+
object_id, created = self._save_update(
456+
doc, save_condition, write_concern
457+
)
458+
except pymongo.errors.DuplicateKeyError as err:
445459
message = "Tried to save duplicate unique keys (%s)"
446460
raise NotUniqueError(message % err)
447-
raise OperationError(message % err)
461+
except pymongo.errors.OperationFailure as err:
462+
message = "Could not save document (%s)"
463+
if re.match("^E1100[01] duplicate key", str(err)):
464+
# E11000 - duplicate key error index
465+
# E11001 - duplicate key on update
466+
message = "Tried to save duplicate unique keys (%s)"
467+
raise NotUniqueError(message % err)
468+
raise OperationError(message % err)
469+
470+
# Make sure we store the PK on this document now that it's saved
471+
id_field = self._meta["id_field"]
472+
if created or id_field not in self._meta.get("shard_key", []):
473+
self[id_field] = self._fields[id_field].to_python(object_id)
474+
475+
signals.post_save.send(
476+
self.__class__, document=self, created=created, **signal_kwargs
477+
)
448478

449-
# Make sure we store the PK on this document now that it's saved
450-
id_field = self._meta["id_field"]
451-
if created or id_field not in self._meta.get("shard_key", []):
452-
self[id_field] = self._fields[id_field].to_python(object_id)
479+
self._clear_changed_fields()
480+
self._created = False
481+
except Exception as e:
482+
raise e
483+
finally:
484+
self._is_saving = False
453485

454-
signals.post_save.send(
455-
self.__class__, document=self, created=created, **signal_kwargs
456-
)
486+
return self
457487

458-
self._clear_changed_fields()
459-
self._created = False
488+
def _save_place_holder(self, force_insert, write_concern):
489+
"""Save a temp placeholder to the db with nothing but the ID.
490+
"""
491+
data = SON()
460492

461-
return self
493+
object_id = self._save_create(data, force_insert, write_concern)
494+
495+
id_field = self._meta["id_field"]
496+
self[id_field] = self._fields[id_field].to_python(object_id)
462497

463498
def _save_create(self, doc, force_insert, write_concern):
464499
"""Save a new document.
@@ -556,28 +591,11 @@ def cascade_save(self, **kwargs):
556591
"""Recursively save any references and generic references on the
557592
document.
558593
"""
559-
_refs = kwargs.get("_refs") or []
560-
561-
ReferenceField = _import_class("ReferenceField")
562-
GenericReferenceField = _import_class("GenericReferenceField")
563594

564595
for name, cls in self._fields.items():
565-
if not isinstance(cls, (ReferenceField, GenericReferenceField)):
566-
continue
567-
568-
ref = self._data.get(name)
569-
if not ref or isinstance(ref, DBRef):
596+
if not isinstance(cls, SaveableBaseField):
570597
continue
571-
572-
if not getattr(ref, "_changed_fields", True):
573-
continue
574-
575-
ref_id = f"{ref.__class__.__name__},{str(ref._data)}"
576-
if ref and ref_id not in _refs:
577-
_refs.append(ref_id)
578-
kwargs["_refs"] = _refs
579-
ref.save(**kwargs)
580-
ref._changed_fields = []
598+
cls.save(self, **kwargs)
581599

582600
@property
583601
def _qs(self):

mongoengine/fields.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from mongoengine.base import (
2626
BaseDocument,
2727
BaseField,
28+
SaveableBaseField,
2829
ComplexBaseField,
2930
GeoJsonBaseField,
3031
LazyReference,
@@ -1123,7 +1124,7 @@ def __init__(self, field=None, *args, **kwargs):
11231124
super().__init__(field=field, *args, **kwargs)
11241125

11251126

1126-
class ReferenceField(BaseField):
1127+
class ReferenceField(SaveableBaseField):
11271128
"""A reference to a document that will be automatically dereferenced on
11281129
access (lazily).
11291130
@@ -1295,6 +1296,16 @@ def validate(self, value):
12951296
"saved to the database"
12961297
)
12971298

1299+
def save(self, instance, **kwargs):
1300+
ref = instance._data.get(self.name)
1301+
if not ref or isinstance(ref, DBRef):
1302+
return
1303+
1304+
if not getattr(self, "_changed_fields", True):
1305+
return
1306+
1307+
ref.save(**kwargs)
1308+
12981309
def lookup_member(self, member_name):
12991310
return self.document_type._fields.get(member_name)
13001311

@@ -1464,7 +1475,7 @@ def sync_all(self):
14641475
self.owner_document.objects(**filter_kwargs).update(**update_kwargs)
14651476

14661477

1467-
class GenericReferenceField(BaseField):
1478+
class GenericReferenceField(SaveableBaseField):
14681479
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
14691480
that will be automatically dereferenced on access (lazily).
14701481
@@ -1546,6 +1557,16 @@ def validate(self, value):
15461557
" saved to the database"
15471558
)
15481559

1560+
def save(self, instance, **kwargs):
1561+
ref = instance._data.get(self.name)
1562+
if not ref or isinstance(ref, DBRef):
1563+
return
1564+
1565+
if not getattr(ref, "_changed_fields", True):
1566+
return
1567+
1568+
ref.save(**kwargs)
1569+
15491570
def to_mongo(self, document):
15501571
if document is None:
15511572
return None

0 commit comments

Comments
 (0)