From 01838b81c175f96be3dc5ed2b6905a410b337a05 Mon Sep 17 00:00:00 2001 From: "Terence D. Honles" Date: Tue, 27 May 2025 18:20:47 +0200 Subject: [PATCH] fix GenericReference iterable query (i.e. ``__in``) This change adds the ``_ref`` or ``_ref.$id`` prefix to a query if all values in an iterable query (i.e. ``__in``) are ``ObjectId``s or ``DBRef``s and raises an error for a mixed query which will only work for documents. These could possibly be compiled into an ``{$or: ...}`` query, but the automatic expansion can be added as necessary. --- mongoengine/queryset/transform.py | 20 ++++++++++++++++++-- tests/queryset/test_transform.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 195bc0b0b..701ca649b 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -129,6 +129,7 @@ def query(_doc_cls=None, **kwargs): singular_ops = [None, "ne", "gt", "gte", "lt", "lte", "not"] singular_ops += STRING_OPERATORS + is_iterable = False if op in singular_ops: value = field.prepare_query_value(op, value) @@ -136,6 +137,7 @@ def query(_doc_cls=None, **kwargs): value = value["_id"] elif op in ("in", "nin", "all", "near") and not isinstance(value, dict): + is_iterable = True # Raise an error if the in/nin/all/near param is not iterable. value = _prepare_query_for_iterable(field, op, value) @@ -144,10 +146,24 @@ def query(_doc_cls=None, **kwargs): # * If the value is a DBRef, the key should be "field_name._ref". # * If the value is an ObjectId, the key should be "field_name._ref.$id". if isinstance(field, GenericReferenceField): - if isinstance(value, DBRef): + if isinstance(value, DBRef) or ( + is_iterable and all(isinstance(v, DBRef) for v in value) + ): parts[-1] += "._ref" - elif isinstance(value, ObjectId): + elif isinstance(value, ObjectId) or ( + is_iterable and all(isinstance(v, ObjectId) for v in value) + ): parts[-1] += "._ref.$id" + elif ( + is_iterable + and any(isinstance(v, DBRef) for v in value) + and any(isinstance(v, ObjectId) for v in value) + ): + raise ValueError( + "The `in`, `nin`, `all`, or `near`-operators cannot " + "be applied to mixed queries of DBRef/ObjectId/%s" + % _doc_cls.__name__ + ) # if op and op not in COMPARISON_OPERATORS: if op: diff --git a/tests/queryset/test_transform.py b/tests/queryset/test_transform.py index 8704187b8..8cb8ad426 100644 --- a/tests/queryset/test_transform.py +++ b/tests/queryset/test_transform.py @@ -396,6 +396,34 @@ class Shop(Document): Shop.drop_collection() + def test_transform_generic_reference_field(self): + class Object(Document): + field = GenericReferenceField() + + Object.drop_collection() + objects = Object.objects.insert([Object() for _ in range(8)]) + # singular queries + assert transform.query(Object, field=objects[0].pk) == { + "field._ref.$id": objects[0].pk + } + assert transform.query(Object, field=objects[1].to_dbref()) == { + "field._ref": objects[1].to_dbref() + } + + # iterable queries + assert transform.query(Object, field__in=[objects[2].pk, objects[3].pk]) == { + "field._ref.$id": {"$in": [objects[2].pk, objects[3].pk]} + } + assert transform.query( + Object, field__in=[objects[4].to_dbref(), objects[5].to_dbref()] + ) == {"field._ref": {"$in": [objects[4].to_dbref(), objects[5].to_dbref()]}} + + # invalid query + with pytest.raises(match="cannot be applied to mixed queries"): + transform.query(Object, field__in=[objects[6].pk, objects[7].to_dbref()]) + + Object.drop_collection() + if __name__ == "__main__": unittest.main()