From 77a44d66b39dcc6d27170e6f87a31ee521f9fd07 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Mon, 5 Feb 2024 22:14:41 +0100 Subject: [PATCH 01/11] Deduce model field names from custom prefetches --- strawberry_django/optimizer.py | 97 +++++++++++++++++++++++++--------- tests/projects/schema.py | 26 ++++++++- tests/test_optimizer.py | 45 ++++++++++++++++ 3 files changed, 143 insertions(+), 25 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 7cafdb13..dc7a5a6f 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -189,6 +189,17 @@ def with_hints( ), ) + def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]: + custom_prefetches = [] + for p in self.prefetch_related: + if isinstance(p, Callable): + assert_type(p, PrefetchCallable) + p = p(info) # noqa: PLW2901 + + if isinstance(p, Prefetch) and p.queryset is not None and p.to_attr is not None: + custom_prefetches.append(p) + return custom_prefetches + def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo): prefetch_related = [] for p in self.prefetch_related: @@ -458,7 +469,8 @@ def _get_model_hints( continue # Add annotations from the field if they exist - field_store = getattr(field, "store", None) + field_store = cast(OptimizerStore | None, getattr(field, "store", None)) + custom_prefetches: list[Prefetch] = [] if field_store is not None: if ( len(field_store.annotate) == 1 @@ -477,20 +489,45 @@ def _get_model_hints( store |= ( field_store.with_prefix(prefix, info=info) if prefix else field_store ) + custom_prefetches = field_store.get_custom_prefetches(info) # Then from the model property if one is defined model_attr = getattr(model, field.python_name, None) if model_attr is not None and isinstance(model_attr, ModelProperty): attr_store = model_attr.store store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store + attr_store_prefetches = attr_store.get_custom_prefetches() + if attr_store_prefetches: + custom_prefetches.extend(attr_store_prefetches) + + model_fieldname: str | None = None + model_field = None + # try to find the model field name in any custom prefetches + if custom_prefetches: + for prefetch in custom_prefetches: + prefetch_field = model_fields.get(prefetch.prefetch_through, None) + if prefetch_field: + if not model_field: + model_field = prefetch_field + model_fieldname = prefetch.prefetch_through + elif model_field != prefetch_field: + # we found more than one model field from the custom prefetches + # not much we can do here + model_field = None + model_fieldname = None + custom_prefetches = [] + break # Lastly, from the django field itself - model_fieldname: str = getattr(field, "django_name", None) or field.python_name - model_field = model_fields.get(model_fieldname, None) + if not model_fieldname: + model_fieldname = getattr(field, "django_name", None) or field.python_name + model_field = model_fields.get(model_fieldname, None) + if model_field is not None: path = f"{prefix}{model_fieldname}" - if isinstance(model_field, (models.ForeignKey, OneToOneRel)): + if not custom_prefetches and isinstance(model_field, (models.ForeignKey, OneToOneRel)): + # only select_related if there is no custom prefetch store.only.append(path) store.select_related.append(path) @@ -517,7 +554,7 @@ def _get_model_hints( if f_store is not None: cache.setdefault(f_model, []).append((level, f_store)) store |= f_store.with_prefix(path, info=info) - elif GenericForeignKey and isinstance(model_field, GenericForeignKey): + elif not custom_prefetches and GenericForeignKey and isinstance(model_field, GenericForeignKey): # There's not much we can do to optimize generic foreign keys regarding # only/select_related because they can be anything. # Just prefetch_related them @@ -530,7 +567,8 @@ def _get_model_hints( if len(f_types) > 1: # This might be a generic foreign key. # In this case, just prefetch it - store.prefetch_related.append(model_fieldname) + if not custom_prefetches: + store.prefetch_related.append(model_fieldname) elif len(f_types) == 1: remote_field = model_field.remote_field remote_model = remote_field.model @@ -590,24 +628,35 @@ def _get_model_hints( cache.setdefault(remote_model, []).append((level, f_store)) - # If prefetch_custom_queryset is false, use _base_manager here - # instead of _default_manager because we are getting related - # objects, and not querying it directly. Else use the type's - # get_queryset and model's custom QuerySet. - base_qs = _get_prefetch_queryset( - remote_model, - field, - config, - info, - ) - f_qs = f_store.apply( - base_qs, - info=info, - config=config, - ) - f_prefetch = Prefetch(path, queryset=f_qs) - f_prefetch._optimizer_sentinel = _sentinel # type: ignore - store.prefetch_related.append(f_prefetch) + if custom_prefetches: + for prefetch in custom_prefetches: + f_qs = f_store.apply( + prefetch.queryset, info=info, config=config + ) + f_prefetch = Prefetch(prefetch.prefetch_through, f_qs, prefetch.to_attr) + if prefix: + f_prefetch.add_prefix(prefix) + f_prefetch._optimizer_sentinel = _sentinel # type: ignore + store.prefetch_related.append(f_prefetch) + else: + # If prefetch_custom_queryset is false, use _base_manager here + # instead of _default_manager because we are getting related + # objects, and not querying it directly. Else use the type's + # get_queryset and model's custom QuerySet. + base_qs = _get_prefetch_queryset( + remote_model, + field, + config, + info, + ) + f_qs = f_store.apply( + base_qs, + info=info, + config=config, + ) + f_prefetch = Prefetch(path, queryset=f_qs) + f_prefetch._optimizer_sentinel = _sentinel # type: ignore + store.prefetch_related.append(f_prefetch) else: store.only.append(path) diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 6c48f979..88a7b5a2 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -1,7 +1,7 @@ import asyncio import datetime import decimal -from typing import Iterable, List, Optional, Type, cast +from typing import Iterable, List, Optional, Type, cast, Union import strawberry from django.contrib.auth import get_user_model @@ -103,6 +103,22 @@ class ProjectType(relay.Node): cost: strawberry.auto = strawberry_django.field(extensions=[IsAuthenticated()]) is_small: strawberry.auto + @strawberry_django.field( + prefetch_related=lambda _: Prefetch( + "milestones", + to_attr="next_milestones_pf", + queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date") + ) + ) + def next_milestones(self) -> "list[MilestoneType]": + """ + The milestones for the project ordered by their due date + """ + if hasattr(self, 'next_milestones_pf'): + return self.next_milestones_pf + else: + return self.milestones.filter(due_date__isnull=False).order_by("due_date") + @strawberry_django.filter(Milestone, lookups=True) class MilestoneFilter: @@ -296,6 +312,14 @@ class ProjectConnection(ListConnectionWithTotalCount[ProjectType]): """Project connection documentation.""" +ProjectFeedItem = Annotated[Union[IssueType, MilestoneType], strawberry.union('ProjectFeedItem')] + + +@strawberry.type +class ProjectFeedConnection(relay.Connection[ProjectFeedItem]): + pass + + @strawberry.type class Query: """All available queries for this schema.""" diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 13e69b3d..65197975 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -534,6 +534,51 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient): assert res.data == {"project": e} +@pytest.mark.django_db(transaction=True) +def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery { + projectList { + id + nextMilestones { + id + name + project { + id + name + } + } + } + } + """ + + expected = [] + for p in ProjectFactory.create_batch(2): + p_res: dict[str, Any] = { + "id": to_base64("ProjectType", p.id), + "nextMilestones": [], + } + expected.append(p_res) + milestones = MilestoneFactory.create_batch(2, project=p) + milestones.sort(key=lambda ms: ms.due_date) + for m in milestones: + m_res: dict[str, Any] = { + "id": to_base64("MilestoneType", m.id), + "name": m.name, + "project": { + "id": p_res["id"], + "name": p.name, + }, + } + p_res["nextMilestones"].append(m_res) + + assert len(expected) == 2 + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3): + res = gql_client.query(query) + assert res.data == {"projectList": expected} + + @pytest.mark.django_db(transaction=True) def test_query_connection_with_resolver(db, gql_client: GraphQLTestClient): query = """ From 0088e393a1d13103ebed6c09860b7b02aea3b8a9 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Mon, 5 Feb 2024 22:21:51 +0100 Subject: [PATCH 02/11] Add test for custom prefetch deduction on model_property --- strawberry_django/optimizer.py | 2 +- tests/projects/models.py | 21 ++++++++++++++-- tests/projects/schema.py | 2 ++ tests/test_optimizer.py | 45 ++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 3 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index dc7a5a6f..94d3b69b 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -496,7 +496,7 @@ def _get_model_hints( if model_attr is not None and isinstance(model_attr, ModelProperty): attr_store = model_attr.store store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store - attr_store_prefetches = attr_store.get_custom_prefetches() + attr_store_prefetches = attr_store.get_custom_prefetches(info) if attr_store_prefetches: custom_prefetches.extend(attr_store_prefetches) diff --git a/tests/projects/models.py b/tests/projects/models.py index 448a1ff4..7b89accd 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Annotated +import strawberry from django.contrib.auth import get_user_model from django.db import models -from django.db.models import Count, QuerySet +from django.db.models import Count, QuerySet, Prefetch from django.utils.translation import gettext_lazy as _ from django_choices_field import TextChoicesField @@ -54,6 +55,22 @@ class Status(models.TextChoices): def is_small(self) -> bool: return self._milestone_count < 3 # type: ignore + @model_property( + prefetch_related=lambda _: Prefetch( + "milestones", + to_attr="next_milestones_prop_pf", + queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date") + ) + ) + def next_milestones_property(self) -> list[Annotated['MilestoneType', strawberry.lazy('.schema')]]: + """ + The milestones for the project ordered by their due date + """ + if hasattr(self, 'next_milestones_prop_pf'): + return self.next_milestones_prop_pf + else: + return self.milestones.filter(due_date__isnull=False).order_by("due_date") + class Milestone(models.Model): issues: "RelatedManager[Issue]" diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 88a7b5a2..7f8b44e4 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -103,6 +103,8 @@ class ProjectType(relay.Node): cost: strawberry.auto = strawberry_django.field(extensions=[IsAuthenticated()]) is_small: strawberry.auto + next_milestones_property: strawberry.auto + @strawberry_django.field( prefetch_related=lambda _: Prefetch( "milestones", diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 65197975..e36b1a05 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -579,6 +579,51 @@ def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient): assert res.data == {"projectList": expected} +@pytest.mark.django_db(transaction=True) +def test_query_prefetch_with_to_attr_model_property(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery { + projectList { + id + nextMilestonesProperty { + id + name + project { + id + name + } + } + } + } + """ + + expected = [] + for p in ProjectFactory.create_batch(2): + p_res: dict[str, Any] = { + "id": to_base64("ProjectType", p.id), + "nextMilestonesProperty": [], + } + expected.append(p_res) + milestones = MilestoneFactory.create_batch(2, project=p) + milestones.sort(key=lambda ms: ms.due_date) + for m in milestones: + m_res: dict[str, Any] = { + "id": to_base64("MilestoneType", m.id), + "name": m.name, + "project": { + "id": p_res["id"], + "name": p.name, + }, + } + p_res["nextMilestonesProperty"].append(m_res) + + assert len(expected) == 2 + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3): + res = gql_client.query(query) + assert res.data == {"projectList": expected} + + @pytest.mark.django_db(transaction=True) def test_query_connection_with_resolver(db, gql_client: GraphQLTestClient): query = """ From f82d3f9c51d6b22c715eb44fabf1dbd493a97247 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Mon, 5 Feb 2024 22:37:51 +0100 Subject: [PATCH 03/11] All prefetches with to_attr are custom, even without queryset --- strawberry_django/optimizer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 94d3b69b..23fa0642 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -196,7 +196,7 @@ def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]: assert_type(p, PrefetchCallable) p = p(info) # noqa: PLW2901 - if isinstance(p, Prefetch) and p.queryset is not None and p.to_attr is not None: + if isinstance(p, Prefetch) and p.to_attr is not None: custom_prefetches.append(p) return custom_prefetches @@ -630,9 +630,11 @@ def _get_model_hints( if custom_prefetches: for prefetch in custom_prefetches: - f_qs = f_store.apply( - prefetch.queryset, info=info, config=config - ) + if prefetch.queryset is not None: + p_qs = prefetch.queryset + else: + p_qs = _get_prefetch_queryset(remote_model, field, config, info) + f_qs = f_store.apply(p_qs, info=info, config=config) f_prefetch = Prefetch(prefetch.prefetch_through, f_qs, prefetch.to_attr) if prefix: f_prefetch.add_prefix(prefix) From 0b2d41d141156ce807629ed1aec7a69068017fd8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 21:42:50 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry_django/optimizer.py | 18 ++++++++++++++---- tests/projects/models.py | 20 +++++++++++--------- tests/projects/schema.py | 18 ++++++++++-------- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 23fa0642..1611edbe 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -526,7 +526,9 @@ def _get_model_hints( if model_field is not None: path = f"{prefix}{model_fieldname}" - if not custom_prefetches and isinstance(model_field, (models.ForeignKey, OneToOneRel)): + if not custom_prefetches and isinstance( + model_field, (models.ForeignKey, OneToOneRel) + ): # only select_related if there is no custom prefetch store.only.append(path) store.select_related.append(path) @@ -554,7 +556,11 @@ def _get_model_hints( if f_store is not None: cache.setdefault(f_model, []).append((level, f_store)) store |= f_store.with_prefix(path, info=info) - elif not custom_prefetches and GenericForeignKey and isinstance(model_field, GenericForeignKey): + elif ( + not custom_prefetches + and GenericForeignKey + and isinstance(model_field, GenericForeignKey) + ): # There's not much we can do to optimize generic foreign keys regarding # only/select_related because they can be anything. # Just prefetch_related them @@ -633,9 +639,13 @@ def _get_model_hints( if prefetch.queryset is not None: p_qs = prefetch.queryset else: - p_qs = _get_prefetch_queryset(remote_model, field, config, info) + p_qs = _get_prefetch_queryset( + remote_model, field, config, info + ) f_qs = f_store.apply(p_qs, info=info, config=config) - f_prefetch = Prefetch(prefetch.prefetch_through, f_qs, prefetch.to_attr) + f_prefetch = Prefetch( + prefetch.prefetch_through, f_qs, prefetch.to_attr + ) if prefix: f_prefetch.add_prefix(prefix) f_prefetch._optimizer_sentinel = _sentinel # type: ignore diff --git a/tests/projects/models.py b/tests/projects/models.py index 7b89accd..6ea42d88 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING, Any, Optional, Annotated +from typing import TYPE_CHECKING, Annotated, Any, Optional import strawberry from django.contrib.auth import get_user_model from django.db import models -from django.db.models import Count, QuerySet, Prefetch +from django.db.models import Count, Prefetch, QuerySet from django.utils.translation import gettext_lazy as _ from django_choices_field import TextChoicesField @@ -59,17 +59,19 @@ def is_small(self) -> bool: prefetch_related=lambda _: Prefetch( "milestones", to_attr="next_milestones_prop_pf", - queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date") + queryset=Milestone.objects.filter(due_date__isnull=False).order_by( + "due_date" + ), ) ) - def next_milestones_property(self) -> list[Annotated['MilestoneType', strawberry.lazy('.schema')]]: + def next_milestones_property( + self, + ) -> list[Annotated["MilestoneType", strawberry.lazy(".schema")]]: + """The milestones for the project ordered by their due date """ - The milestones for the project ordered by their due date - """ - if hasattr(self, 'next_milestones_prop_pf'): + if hasattr(self, "next_milestones_prop_pf"): return self.next_milestones_prop_pf - else: - return self.milestones.filter(due_date__isnull=False).order_by("due_date") + return self.milestones.filter(due_date__isnull=False).order_by("due_date") class Milestone(models.Model): diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 7f8b44e4..ef67aac2 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -1,7 +1,7 @@ import asyncio import datetime import decimal -from typing import Iterable, List, Optional, Type, cast, Union +from typing import Iterable, List, Optional, Type, Union, cast import strawberry from django.contrib.auth import get_user_model @@ -109,17 +109,17 @@ class ProjectType(relay.Node): prefetch_related=lambda _: Prefetch( "milestones", to_attr="next_milestones_pf", - queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date") + queryset=Milestone.objects.filter(due_date__isnull=False).order_by( + "due_date" + ), ) ) def next_milestones(self) -> "list[MilestoneType]": + """The milestones for the project ordered by their due date """ - The milestones for the project ordered by their due date - """ - if hasattr(self, 'next_milestones_pf'): + if hasattr(self, "next_milestones_pf"): return self.next_milestones_pf - else: - return self.milestones.filter(due_date__isnull=False).order_by("due_date") + return self.milestones.filter(due_date__isnull=False).order_by("due_date") @strawberry_django.filter(Milestone, lookups=True) @@ -314,7 +314,9 @@ class ProjectConnection(ListConnectionWithTotalCount[ProjectType]): """Project connection documentation.""" -ProjectFeedItem = Annotated[Union[IssueType, MilestoneType], strawberry.union('ProjectFeedItem')] +ProjectFeedItem = Annotated[ + Union[IssueType, MilestoneType], strawberry.union("ProjectFeedItem") +] @strawberry.type From 16ff87365aa62b9b17fe112704d5d42d50a0444c Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 3 Mar 2024 17:23:22 +0100 Subject: [PATCH 05/11] Fix typing issues --- strawberry_django/optimizer.py | 15 ++++++++++----- tests/projects/models.py | 12 ++++++++---- tests/projects/schema.py | 7 ++++--- tests/test_optimizer.py | 4 ++-- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 1611edbe..d3cb1894 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -15,6 +15,7 @@ Type, TypeVar, cast, + Optional, ) from django.db import models @@ -196,7 +197,8 @@ def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]: assert_type(p, PrefetchCallable) p = p(info) # noqa: PLW2901 - if isinstance(p, Prefetch) and p.to_attr is not None: + # to_attr is not typed in django stubs + if isinstance(p, Prefetch) and p.to_attr is not None: # type: ignore custom_prefetches.append(p) return custom_prefetches @@ -469,7 +471,7 @@ def _get_model_hints( continue # Add annotations from the field if they exist - field_store = cast(OptimizerStore | None, getattr(field, "store", None)) + field_store = cast(Optional[OptimizerStore], getattr(field, "store", None)) custom_prefetches: list[Prefetch] = [] if field_store is not None: if ( @@ -521,9 +523,10 @@ def _get_model_hints( # Lastly, from the django field itself if not model_fieldname: model_fieldname = getattr(field, "django_name", None) or field.python_name - model_field = model_fields.get(model_fieldname, None) + model_field = model_fields.get(model_fieldname, None) if model_fieldname else None if model_field is not None: + assert model_fieldname is not None # if we have a model_field, then model_fieldname must also be set path = f"{prefix}{model_fieldname}" if not custom_prefetches and isinstance( @@ -636,7 +639,8 @@ def _get_model_hints( if custom_prefetches: for prefetch in custom_prefetches: - if prefetch.queryset is not None: + # stubs incorrectly say that queryset is never None + if prefetch.queryset is not None: # type: ignore p_qs = prefetch.queryset else: p_qs = _get_prefetch_queryset( @@ -644,7 +648,8 @@ def _get_model_hints( ) f_qs = f_store.apply(p_qs, info=info, config=config) f_prefetch = Prefetch( - prefetch.prefetch_through, f_qs, prefetch.to_attr + # to_attr is not typed in django stubs + prefetch.prefetch_through, f_qs, prefetch.to_attr # type: ignore ) if prefix: f_prefetch.add_prefix(prefix) diff --git a/tests/projects/models.py b/tests/projects/models.py index 6ea42d88..51dfb3a0 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Annotated, Any, Optional +from typing import TYPE_CHECKING, Any, List, Optional +from typing_extensions import Annotated import strawberry from django.contrib.auth import get_user_model @@ -12,6 +13,7 @@ if TYPE_CHECKING: from django.db.models.manager import RelatedManager + from .schema import MilestoneType User = get_user_model() @@ -51,6 +53,8 @@ class Status(models.TextChoices): default=None, ) + next_milestones_prop_pf: List["Milestone"] + @model_property(annotate={"_milestone_count": Count("milestone")}) def is_small(self) -> bool: return self._milestone_count < 3 # type: ignore @@ -66,12 +70,12 @@ def is_small(self) -> bool: ) def next_milestones_property( self, - ) -> list[Annotated["MilestoneType", strawberry.lazy(".schema")]]: + ) -> List[Annotated["MilestoneType", strawberry.lazy(".schema")]]: """The milestones for the project ordered by their due date """ if hasattr(self, "next_milestones_prop_pf"): - return self.next_milestones_prop_pf - return self.milestones.filter(due_date__isnull=False).order_by("due_date") + return self.next_milestones_prop_pf # type: ignore + return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore class Milestone(models.Model): diff --git a/tests/projects/schema.py b/tests/projects/schema.py index ef67aac2..89ec693f 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -104,6 +104,7 @@ class ProjectType(relay.Node): is_small: strawberry.auto next_milestones_property: strawberry.auto + next_milestones_pf: List[Milestone] @strawberry_django.field( prefetch_related=lambda _: Prefetch( @@ -114,12 +115,12 @@ class ProjectType(relay.Node): ), ) ) - def next_milestones(self) -> "list[MilestoneType]": + def next_milestones(self) -> "List[MilestoneType]": """The milestones for the project ordered by their due date """ if hasattr(self, "next_milestones_pf"): - return self.next_milestones_pf - return self.milestones.filter(due_date__isnull=False).order_by("due_date") + return self.next_milestones_pf # type: ignore + return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore @strawberry_django.filter(Milestone, lookups=True) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index e36b1a05..85672488 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -560,7 +560,7 @@ def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient): } expected.append(p_res) milestones = MilestoneFactory.create_batch(2, project=p) - milestones.sort(key=lambda ms: ms.due_date) + milestones.sort(key=lambda ms: cast(datetime.datetime, ms.due_date)) for m in milestones: m_res: dict[str, Any] = { "id": to_base64("MilestoneType", m.id), @@ -605,7 +605,7 @@ def test_query_prefetch_with_to_attr_model_property(db, gql_client: GraphQLTestC } expected.append(p_res) milestones = MilestoneFactory.create_batch(2, project=p) - milestones.sort(key=lambda ms: ms.due_date) + milestones.sort(key=lambda ms: cast(datetime.datetime, ms.due_date)) for m in milestones: m_res: dict[str, Any] = { "id": to_base64("MilestoneType", m.id), From 6d4889db23226269670b390fd06aea8e777b0d4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Mar 2024 16:26:33 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry_django/optimizer.py | 14 ++++++++++---- tests/projects/models.py | 6 +++--- tests/projects/schema.py | 3 +-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index d3cb1894..5bd817d5 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -12,10 +12,10 @@ Callable, ForwardRef, Generator, + Optional, Type, TypeVar, cast, - Optional, ) from django.db import models @@ -523,10 +523,14 @@ def _get_model_hints( # Lastly, from the django field itself if not model_fieldname: model_fieldname = getattr(field, "django_name", None) or field.python_name - model_field = model_fields.get(model_fieldname, None) if model_fieldname else None + model_field = ( + model_fields.get(model_fieldname, None) if model_fieldname else None + ) if model_field is not None: - assert model_fieldname is not None # if we have a model_field, then model_fieldname must also be set + assert ( + model_fieldname is not None + ) # if we have a model_field, then model_fieldname must also be set path = f"{prefix}{model_fieldname}" if not custom_prefetches and isinstance( @@ -649,7 +653,9 @@ def _get_model_hints( f_qs = f_store.apply(p_qs, info=info, config=config) f_prefetch = Prefetch( # to_attr is not typed in django stubs - prefetch.prefetch_through, f_qs, prefetch.to_attr # type: ignore + prefetch.prefetch_through, + f_qs, + prefetch.to_attr, # type: ignore ) if prefix: f_prefetch.add_prefix(prefix) diff --git a/tests/projects/models.py b/tests/projects/models.py index 51dfb3a0..09cd1694 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -1,5 +1,4 @@ from typing import TYPE_CHECKING, Any, List, Optional -from typing_extensions import Annotated import strawberry from django.contrib.auth import get_user_model @@ -7,12 +6,14 @@ from django.db.models import Count, Prefetch, QuerySet from django.utils.translation import gettext_lazy as _ from django_choices_field import TextChoicesField +from typing_extensions import Annotated from strawberry_django.descriptors import model_property from strawberry_django.utils.typing import UserType if TYPE_CHECKING: from django.db.models.manager import RelatedManager + from .schema import MilestoneType User = get_user_model() @@ -71,8 +72,7 @@ def is_small(self) -> bool: def next_milestones_property( self, ) -> List[Annotated["MilestoneType", strawberry.lazy(".schema")]]: - """The milestones for the project ordered by their due date - """ + """The milestones for the project ordered by their due date""" if hasattr(self, "next_milestones_prop_pf"): return self.next_milestones_prop_pf # type: ignore return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 89ec693f..aed342b4 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -116,8 +116,7 @@ class ProjectType(relay.Node): ) ) def next_milestones(self) -> "List[MilestoneType]": - """The milestones for the project ordered by their due date - """ + """The milestones for the project ordered by their due date""" if hasattr(self, "next_milestones_pf"): return self.next_milestones_pf # type: ignore return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore From d46a2d3e63318514e30d168a0a1d984aa36787da Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 3 Mar 2024 17:32:54 +0100 Subject: [PATCH 07/11] Update graphql snapshot test schema --- tests/projects/schema.py | 1 - tests/projects/snapshots/schema.gql | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/projects/schema.py b/tests/projects/schema.py index aed342b4..034b40b7 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -104,7 +104,6 @@ class ProjectType(relay.Node): is_small: strawberry.auto next_milestones_property: strawberry.auto - next_milestones_pf: List[Milestone] @strawberry_django.field( prefetch_related=lambda _: Prefetch( diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index b89e76ce..90127ceb 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -471,6 +471,8 @@ type ProjectType implements Node { isDelayed: Boolean! cost: Decimal @isAuthenticated isSmall: Boolean! + nextMilestonesProperty(filters: MilestoneFilter, order: MilestoneOrder): [MilestoneType!]! + nextMilestones: [MilestoneType!]! } """An edge in a connection.""" From 6f64194ba4e621e3892a60aa7527f9505252fd2b Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 3 Mar 2024 17:43:30 +0100 Subject: [PATCH 08/11] Improve schema typing again --- tests/projects/models.py | 8 +++++--- tests/projects/schema.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/projects/models.py b/tests/projects/models.py index 09cd1694..05816209 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, cast import strawberry from django.contrib.auth import get_user_model @@ -74,8 +74,8 @@ def next_milestones_property( ) -> List[Annotated["MilestoneType", strawberry.lazy(".schema")]]: """The milestones for the project ordered by their due date""" if hasattr(self, "next_milestones_prop_pf"): - return self.next_milestones_prop_pf # type: ignore - return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore + return cast(List["MilestoneType"], self.next_milestones_prop_pf) + return cast(List["MilestoneType"], self.milestones.filter(due_date__isnull=False).order_by("due_date")) class Milestone(models.Model): @@ -101,6 +101,8 @@ class Milestone(models.Model): related_query_name="milestone", ) + next_milestones_pf: List["Milestone"] + class FavoriteQuerySet(QuerySet): def by_user(self, user: UserType): diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 034b40b7..e5bd183d 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -114,11 +114,11 @@ class ProjectType(relay.Node): ), ) ) - def next_milestones(self) -> "List[MilestoneType]": + def next_milestones(self, root: Milestone) -> "List[MilestoneType]": """The milestones for the project ordered by their due date""" - if hasattr(self, "next_milestones_pf"): - return self.next_milestones_pf # type: ignore - return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore + if hasattr(root, "next_milestones_pf"): + return cast(List[MilestoneType], root.next_milestones_pf) + return root.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore @strawberry_django.filter(Milestone, lookups=True) From e25be2dddb987669cbb01e0767bc4f8b651d9e13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Mar 2024 16:43:40 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/projects/models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/projects/models.py b/tests/projects/models.py index 05816209..85704c22 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -75,7 +75,10 @@ def next_milestones_property( """The milestones for the project ordered by their due date""" if hasattr(self, "next_milestones_prop_pf"): return cast(List["MilestoneType"], self.next_milestones_prop_pf) - return cast(List["MilestoneType"], self.milestones.filter(due_date__isnull=False).order_by("due_date")) + return cast( + List["MilestoneType"], + self.milestones.filter(due_date__isnull=False).order_by("due_date"), + ) class Milestone(models.Model): From 92cdc1241e317d322fdbd32494dba742c813f07c Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 3 Mar 2024 17:59:50 +0100 Subject: [PATCH 10/11] Bump "max-nested-blocks" rule to 8 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5d310550..24459122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,7 +175,7 @@ exclude = [ "**/migrations/*" = ["RUF012"] [tool.ruff.lint.pylint] -max-nested-blocks = 7 +max-nested-blocks = 8 [tool.ruff.lint.isort] From 746ec95bf279b5b8038de1768c00a3b3826c0614 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 3 Mar 2024 18:03:03 +0100 Subject: [PATCH 11/11] Fix ruff errors --- tests/projects/models.py | 2 +- tests/projects/schema.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/projects/models.py b/tests/projects/models.py index 85704c22..37fdaf22 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -72,7 +72,7 @@ def is_small(self) -> bool: def next_milestones_property( self, ) -> List[Annotated["MilestoneType", strawberry.lazy(".schema")]]: - """The milestones for the project ordered by their due date""" + """Return the milestones for the project ordered by their due date.""" if hasattr(self, "next_milestones_prop_pf"): return cast(List["MilestoneType"], self.next_milestones_prop_pf) return cast( diff --git a/tests/projects/schema.py b/tests/projects/schema.py index e5bd183d..51afc3c8 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -115,7 +115,7 @@ class ProjectType(relay.Node): ) ) def next_milestones(self, root: Milestone) -> "List[MilestoneType]": - """The milestones for the project ordered by their due date""" + """Return the milestones for the project ordered by their due date.""" if hasattr(root, "next_milestones_pf"): return cast(List[MilestoneType], root.next_milestones_pf) return root.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore