From ba9798023874327e8dac843324975f97a35dc398 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 25 Jan 2025 14:32:55 +0100 Subject: [PATCH] fix(optimizer): Avoid merging prefetches when using aliases Merging querysets is usually fine, but when using an alias it might be that one queryset is filtering for something, and the other is filtering for something else, for the same field. That can lead to wrong results being returned. From now on, if a field is specified more than once with an alias, the optimizer will skip it, possibly causing n+1 issues, but avoiding wrong results (which is worse). Fix #695 --- strawberry_django/optimizer.py | 57 ++++++++++++++++++++----- tests/test_optimizer.py | 76 +++++++++++++++++++++------------- 2 files changed, 93 insertions(+), 40 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index b56340f4..830171dc 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -5,7 +5,7 @@ import copy import dataclasses import itertools -from collections import defaultdict +from collections import Counter, defaultdict from collections.abc import Callable from typing import ( TYPE_CHECKING, @@ -58,6 +58,7 @@ from .utils.inspect import ( PrefetchInspector, get_model_field, + get_model_fields, get_possible_type_definitions, ) from .utils.typing import ( @@ -1035,19 +1036,29 @@ def _get_model_hints( if pk is not None: store.only.append(pk.attname) - for f_selections in _get_selections(info, parent_type).values(): - field_data = _get_field_data( - f_selections, - object_definition, - schema, - parent_type=parent_type, - info=info, + selections = [ + field_data + for f_selection in _get_selections(info, parent_type).values() + if ( + field_data := _get_field_data( + f_selection, + object_definition, + schema, + parent_type=parent_type, + info=info, + ) ) - if field_data is None: + is not None + ] + fields_counter = Counter(field_data[0] for field_data in selections) + + for field, f_definition, f_selection, f_info in selections: + # If a field is selected more than once in the query, that means it is being + # aliased. In this case, optimizing it would make one query to affect the other, + # resulting in wrong results for both. + if fields_counter[field] > 1: continue - field, f_definition, f_selection, f_info = field_data - # Add annotations from the field if they exist if field_store := _get_hints_from_field(field, f_info=f_info, prefix=prefix): store |= field_store @@ -1089,6 +1100,30 @@ def _get_model_hints( store.only.extend(inner_store.only) store.select_related.extend(inner_store.select_related) + # In case we skipped optimization for a relation, we might end up with a new QuerySet + # which would not select its parent relation field on `.only()`, causing n+1 issues. + # Make sure that in this case we also select it. + if level == 0 and store.only and info.path.prev: + own_fk_fields = [ + field + for field in get_model_fields(model).values() + if isinstance(field, models.ForeignKey) + ] + + path = info.path + while path := path.prev: + type_ = schema.get_type_by_name(path.typename) + if not isinstance(type_, StrawberryObjectDefinition): + continue + + if not (strawberry_django_type := get_django_definition(type_.origin)): + continue + + for field in own_fk_fields: + if field.related_model is strawberry_django_type.model: + store.only.append(field.attname) + break + return store diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index c2274250..8858f7ec 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -308,14 +308,6 @@ def test_query_forward_with_fragments(db, gql_client: GraphQLTestClient): } ... milestoneFrag } - milestoneAgain: milestone { - name - project { - id - name - } - ... milestoneFrag - } } } } @@ -341,7 +333,6 @@ def test_query_forward_with_fragments(db, gql_client: GraphQLTestClient): "nameWithKind": f"{i.kind}: {i.name}", "nameWithPriority": f"{i.kind}: {i.priority}", "milestone": m_res, - "milestoneAgain": m_res, }, ) @@ -538,12 +529,6 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient): ... milestoneFrag } } - otherIssues: issues { - id - milestone { - ... milestoneFrag - } - } } } } @@ -566,7 +551,6 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient): "name": p_res["name"], }, "issues": [], - "otherIssues": [], } p_res["milestones"].append(m_res) for i in IssueFactory.create_batch(3, milestone=m): @@ -585,22 +569,10 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient): }, }, ) - m_res["otherIssues"].append( - { - "id": to_base64("IssueType", i.id), - "milestone": { - "id": m_res["id"], - "project": { - "id": p_res["id"], - "name": p_res["name"], - }, - }, - }, - ) assert len(expected) == 3 for e in expected: - with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 8): + with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 5): res = gql_client.query(query, {"node_id": e["id"]}) assert res.data == {"project": e} @@ -1089,6 +1061,52 @@ def test_query_nested_connection_with_filter(db, gql_client: GraphQLTestClient): } == expected +@pytest.mark.django_db(transaction=True) +def test_query_nested_connection_with_filter_and_alias( + db, gql_client: GraphQLTestClient +): + query = """ + query TestQuery ($id: GlobalID!) { + milestone(id: $id) { + id + fooIssues: issuesWithFilters (filters: {search: "Foo"}) { + edges { + node { + id + } + } + } + barIssues: issuesWithFilters (filters: {search: "Bar"}) { + edges { + node { + id + } + } + } + } + } + """ + + milestone = MilestoneFactory.create() + issue1 = IssueFactory.create(milestone=milestone, name="Foo") + issue2 = IssueFactory.create(milestone=milestone, name="Foo Bar") + issue3 = IssueFactory.create(milestone=milestone, name="Bar Foo") + issue4 = IssueFactory.create(milestone=milestone, name="Bar Bin") + + with assert_num_queries(3): + res = gql_client.query(query, {"id": to_base64("MilestoneType", milestone.pk)}) + + assert isinstance(res.data, dict) + result = res.data["milestone"] + assert isinstance(result, dict) + + foo_expected = {to_base64("IssueType", i.pk) for i in [issue1, issue2, issue3]} + assert {edge["node"]["id"] for edge in result["fooIssues"]["edges"]} == foo_expected + + bar_expected = {to_base64("IssueType", i.pk) for i in [issue2, issue3, issue4]} + assert {edge["node"]["id"] for edge in result["barIssues"]["edges"]} == bar_expected + + @pytest.mark.django_db(transaction=True) def test_query_with_optimizer_paginated_prefetch(): @strawberry_django.type(Milestone, pagination=True)