diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 23fa0642..1611edbe 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -526,7 +526,9 @@ def _get_model_hints( if model_field is not None: path = f"{prefix}{model_fieldname}" - if not custom_prefetches and isinstance(model_field, (models.ForeignKey, OneToOneRel)): + if not custom_prefetches and isinstance( + model_field, (models.ForeignKey, OneToOneRel) + ): # only select_related if there is no custom prefetch store.only.append(path) store.select_related.append(path) @@ -554,7 +556,11 @@ def _get_model_hints( if f_store is not None: cache.setdefault(f_model, []).append((level, f_store)) store |= f_store.with_prefix(path, info=info) - elif not custom_prefetches and GenericForeignKey and isinstance(model_field, GenericForeignKey): + elif ( + not custom_prefetches + and GenericForeignKey + and isinstance(model_field, GenericForeignKey) + ): # There's not much we can do to optimize generic foreign keys regarding # only/select_related because they can be anything. # Just prefetch_related them @@ -633,9 +639,13 @@ def _get_model_hints( if prefetch.queryset is not None: p_qs = prefetch.queryset else: - p_qs = _get_prefetch_queryset(remote_model, field, config, info) + p_qs = _get_prefetch_queryset( + remote_model, field, config, info + ) f_qs = f_store.apply(p_qs, info=info, config=config) - f_prefetch = Prefetch(prefetch.prefetch_through, f_qs, prefetch.to_attr) + f_prefetch = Prefetch( + prefetch.prefetch_through, f_qs, prefetch.to_attr + ) if prefix: f_prefetch.add_prefix(prefix) f_prefetch._optimizer_sentinel = _sentinel # type: ignore diff --git a/tests/projects/models.py b/tests/projects/models.py index 7b89accd..6ea42d88 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING, Any, Optional, Annotated +from typing import TYPE_CHECKING, Annotated, Any, Optional import strawberry from django.contrib.auth import get_user_model from django.db import models -from django.db.models import Count, QuerySet, Prefetch +from django.db.models import Count, Prefetch, QuerySet from django.utils.translation import gettext_lazy as _ from django_choices_field import TextChoicesField @@ -59,17 +59,19 @@ def is_small(self) -> bool: prefetch_related=lambda _: Prefetch( "milestones", to_attr="next_milestones_prop_pf", - queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date") + queryset=Milestone.objects.filter(due_date__isnull=False).order_by( + "due_date" + ), ) ) - def next_milestones_property(self) -> list[Annotated['MilestoneType', strawberry.lazy('.schema')]]: + def next_milestones_property( + self, + ) -> list[Annotated["MilestoneType", strawberry.lazy(".schema")]]: + """The milestones for the project ordered by their due date """ - The milestones for the project ordered by their due date - """ - if hasattr(self, 'next_milestones_prop_pf'): + if hasattr(self, "next_milestones_prop_pf"): return self.next_milestones_prop_pf - else: - return self.milestones.filter(due_date__isnull=False).order_by("due_date") + return self.milestones.filter(due_date__isnull=False).order_by("due_date") class Milestone(models.Model): diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 7f8b44e4..ef67aac2 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -1,7 +1,7 @@ import asyncio import datetime import decimal -from typing import Iterable, List, Optional, Type, cast, Union +from typing import Iterable, List, Optional, Type, Union, cast import strawberry from django.contrib.auth import get_user_model @@ -109,17 +109,17 @@ class ProjectType(relay.Node): prefetch_related=lambda _: Prefetch( "milestones", to_attr="next_milestones_pf", - queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date") + queryset=Milestone.objects.filter(due_date__isnull=False).order_by( + "due_date" + ), ) ) def next_milestones(self) -> "list[MilestoneType]": + """The milestones for the project ordered by their due date """ - The milestones for the project ordered by their due date - """ - if hasattr(self, 'next_milestones_pf'): + if hasattr(self, "next_milestones_pf"): return self.next_milestones_pf - else: - return self.milestones.filter(due_date__isnull=False).order_by("due_date") + return self.milestones.filter(due_date__isnull=False).order_by("due_date") @strawberry_django.filter(Milestone, lookups=True) @@ -314,7 +314,9 @@ class ProjectConnection(ListConnectionWithTotalCount[ProjectType]): """Project connection documentation.""" -ProjectFeedItem = Annotated[Union[IssueType, MilestoneType], strawberry.union('ProjectFeedItem')] +ProjectFeedItem = Annotated[ + Union[IssueType, MilestoneType], strawberry.union("ProjectFeedItem") +] @strawberry.type