diff --git a/pyproject.toml b/pyproject.toml index 5d310550..24459122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,7 +175,7 @@ exclude = [ "**/migrations/*" = ["RUF012"] [tool.ruff.lint.pylint] -max-nested-blocks = 7 +max-nested-blocks = 8 [tool.ruff.lint.isort] diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 7cafdb13..5bd817d5 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -12,6 +12,7 @@ Callable, ForwardRef, Generator, + Optional, Type, TypeVar, cast, @@ -189,6 +190,18 @@ def with_hints( ), ) + def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]: + custom_prefetches = [] + for p in self.prefetch_related: + if isinstance(p, Callable): + assert_type(p, PrefetchCallable) + p = p(info) # noqa: PLW2901 + + # to_attr is not typed in django stubs + if isinstance(p, Prefetch) and p.to_attr is not None: # type: ignore + custom_prefetches.append(p) + return custom_prefetches + def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo): prefetch_related = [] for p in self.prefetch_related: @@ -458,7 +471,8 @@ def _get_model_hints( continue # Add annotations from the field if they exist - field_store = getattr(field, "store", None) + field_store = cast(Optional[OptimizerStore], getattr(field, "store", None)) + custom_prefetches: list[Prefetch] = [] if field_store is not None: if ( len(field_store.annotate) == 1 @@ -477,20 +491,52 @@ def _get_model_hints( store |= ( field_store.with_prefix(prefix, info=info) if prefix else field_store ) + custom_prefetches = field_store.get_custom_prefetches(info) # Then from the model property if one is defined model_attr = getattr(model, field.python_name, None) if model_attr is not None and isinstance(model_attr, ModelProperty): attr_store = model_attr.store store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store + attr_store_prefetches = attr_store.get_custom_prefetches(info) + if attr_store_prefetches: + custom_prefetches.extend(attr_store_prefetches) + + model_fieldname: str | None = None + model_field = None + # try to find the model field name in any custom prefetches + if custom_prefetches: + for prefetch in custom_prefetches: + prefetch_field = model_fields.get(prefetch.prefetch_through, None) + if prefetch_field: + if not model_field: + model_field = prefetch_field + model_fieldname = prefetch.prefetch_through + elif model_field != prefetch_field: + # we found more than one model field from the custom prefetches + # not much we can do here + model_field = None + model_fieldname = None + custom_prefetches = [] + break # Lastly, from the django field itself - model_fieldname: str = getattr(field, "django_name", None) or field.python_name - model_field = model_fields.get(model_fieldname, None) + if not model_fieldname: + model_fieldname = getattr(field, "django_name", None) or field.python_name + model_field = ( + model_fields.get(model_fieldname, None) if model_fieldname else None + ) + if model_field is not None: + assert ( + model_fieldname is not None + ) # if we have a model_field, then model_fieldname must also be set path = f"{prefix}{model_fieldname}" - if isinstance(model_field, (models.ForeignKey, OneToOneRel)): + if not custom_prefetches and isinstance( + model_field, (models.ForeignKey, OneToOneRel) + ): + # only select_related if there is no custom prefetch store.only.append(path) store.select_related.append(path) @@ -517,7 +563,11 @@ def _get_model_hints( if f_store is not None: cache.setdefault(f_model, []).append((level, f_store)) store |= f_store.with_prefix(path, info=info) - elif GenericForeignKey and isinstance(model_field, GenericForeignKey): + elif ( + not custom_prefetches + and GenericForeignKey + and isinstance(model_field, GenericForeignKey) + ): # There's not much we can do to optimize generic foreign keys regarding # only/select_related because they can be anything. # Just prefetch_related them @@ -530,7 +580,8 @@ def _get_model_hints( if len(f_types) > 1: # This might be a generic foreign key. # In this case, just prefetch it - store.prefetch_related.append(model_fieldname) + if not custom_prefetches: + store.prefetch_related.append(model_fieldname) elif len(f_types) == 1: remote_field = model_field.remote_field remote_model = remote_field.model @@ -590,24 +641,45 @@ def _get_model_hints( cache.setdefault(remote_model, []).append((level, f_store)) - # If prefetch_custom_queryset is false, use _base_manager here - # instead of _default_manager because we are getting related - # objects, and not querying it directly. Else use the type's - # get_queryset and model's custom QuerySet. - base_qs = _get_prefetch_queryset( - remote_model, - field, - config, - info, - ) - f_qs = f_store.apply( - base_qs, - info=info, - config=config, - ) - f_prefetch = Prefetch(path, queryset=f_qs) - f_prefetch._optimizer_sentinel = _sentinel # type: ignore - store.prefetch_related.append(f_prefetch) + if custom_prefetches: + for prefetch in custom_prefetches: + # stubs incorrectly say that queryset is never None + if prefetch.queryset is not None: # type: ignore + p_qs = prefetch.queryset + else: + p_qs = _get_prefetch_queryset( + remote_model, field, config, info + ) + f_qs = f_store.apply(p_qs, info=info, config=config) + f_prefetch = Prefetch( + # to_attr is not typed in django stubs + prefetch.prefetch_through, + f_qs, + prefetch.to_attr, # type: ignore + ) + if prefix: + f_prefetch.add_prefix(prefix) + f_prefetch._optimizer_sentinel = _sentinel # type: ignore + store.prefetch_related.append(f_prefetch) + else: + # If prefetch_custom_queryset is false, use _base_manager here + # instead of _default_manager because we are getting related + # objects, and not querying it directly. Else use the type's + # get_queryset and model's custom QuerySet. + base_qs = _get_prefetch_queryset( + remote_model, + field, + config, + info, + ) + f_qs = f_store.apply( + base_qs, + info=info, + config=config, + ) + f_prefetch = Prefetch(path, queryset=f_qs) + f_prefetch._optimizer_sentinel = _sentinel # type: ignore + store.prefetch_related.append(f_prefetch) else: store.only.append(path) diff --git a/tests/projects/models.py b/tests/projects/models.py index 448a1ff4..37fdaf22 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -1,10 +1,12 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, List, Optional, cast +import strawberry from django.contrib.auth import get_user_model from django.db import models -from django.db.models import Count, QuerySet +from django.db.models import Count, Prefetch, QuerySet from django.utils.translation import gettext_lazy as _ from django_choices_field import TextChoicesField +from typing_extensions import Annotated from strawberry_django.descriptors import model_property from strawberry_django.utils.typing import UserType @@ -12,6 +14,8 @@ if TYPE_CHECKING: from django.db.models.manager import RelatedManager + from .schema import MilestoneType + User = get_user_model() @@ -50,10 +54,32 @@ class Status(models.TextChoices): default=None, ) + next_milestones_prop_pf: List["Milestone"] + @model_property(annotate={"_milestone_count": Count("milestone")}) def is_small(self) -> bool: return self._milestone_count < 3 # type: ignore + @model_property( + prefetch_related=lambda _: Prefetch( + "milestones", + to_attr="next_milestones_prop_pf", + queryset=Milestone.objects.filter(due_date__isnull=False).order_by( + "due_date" + ), + ) + ) + def next_milestones_property( + self, + ) -> List[Annotated["MilestoneType", strawberry.lazy(".schema")]]: + """Return the milestones for the project ordered by their due date.""" + if hasattr(self, "next_milestones_prop_pf"): + return cast(List["MilestoneType"], self.next_milestones_prop_pf) + return cast( + List["MilestoneType"], + self.milestones.filter(due_date__isnull=False).order_by("due_date"), + ) + class Milestone(models.Model): issues: "RelatedManager[Issue]" @@ -78,6 +104,8 @@ class Milestone(models.Model): related_query_name="milestone", ) + next_milestones_pf: List["Milestone"] + class FavoriteQuerySet(QuerySet): def by_user(self, user: UserType): diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 6c48f979..51afc3c8 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -1,7 +1,7 @@ import asyncio import datetime import decimal -from typing import Iterable, List, Optional, Type, cast +from typing import Iterable, List, Optional, Type, Union, cast import strawberry from django.contrib.auth import get_user_model @@ -103,6 +103,23 @@ class ProjectType(relay.Node): cost: strawberry.auto = strawberry_django.field(extensions=[IsAuthenticated()]) is_small: strawberry.auto + next_milestones_property: strawberry.auto + + @strawberry_django.field( + prefetch_related=lambda _: Prefetch( + "milestones", + to_attr="next_milestones_pf", + queryset=Milestone.objects.filter(due_date__isnull=False).order_by( + "due_date" + ), + ) + ) + def next_milestones(self, root: Milestone) -> "List[MilestoneType]": + """Return the milestones for the project ordered by their due date.""" + if hasattr(root, "next_milestones_pf"): + return cast(List[MilestoneType], root.next_milestones_pf) + return root.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore + @strawberry_django.filter(Milestone, lookups=True) class MilestoneFilter: @@ -296,6 +313,16 @@ class ProjectConnection(ListConnectionWithTotalCount[ProjectType]): """Project connection documentation.""" +ProjectFeedItem = Annotated[ + Union[IssueType, MilestoneType], strawberry.union("ProjectFeedItem") +] + + +@strawberry.type +class ProjectFeedConnection(relay.Connection[ProjectFeedItem]): + pass + + @strawberry.type class Query: """All available queries for this schema.""" diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index b89e76ce..90127ceb 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -471,6 +471,8 @@ type ProjectType implements Node { isDelayed: Boolean! cost: Decimal @isAuthenticated isSmall: Boolean! + nextMilestonesProperty(filters: MilestoneFilter, order: MilestoneOrder): [MilestoneType!]! + nextMilestones: [MilestoneType!]! } """An edge in a connection.""" diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 13e69b3d..85672488 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -534,6 +534,96 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient): assert res.data == {"project": e} +@pytest.mark.django_db(transaction=True) +def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery { + projectList { + id + nextMilestones { + id + name + project { + id + name + } + } + } + } + """ + + expected = [] + for p in ProjectFactory.create_batch(2): + p_res: dict[str, Any] = { + "id": to_base64("ProjectType", p.id), + "nextMilestones": [], + } + expected.append(p_res) + milestones = MilestoneFactory.create_batch(2, project=p) + milestones.sort(key=lambda ms: cast(datetime.datetime, ms.due_date)) + for m in milestones: + m_res: dict[str, Any] = { + "id": to_base64("MilestoneType", m.id), + "name": m.name, + "project": { + "id": p_res["id"], + "name": p.name, + }, + } + p_res["nextMilestones"].append(m_res) + + assert len(expected) == 2 + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3): + res = gql_client.query(query) + assert res.data == {"projectList": expected} + + +@pytest.mark.django_db(transaction=True) +def test_query_prefetch_with_to_attr_model_property(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery { + projectList { + id + nextMilestonesProperty { + id + name + project { + id + name + } + } + } + } + """ + + expected = [] + for p in ProjectFactory.create_batch(2): + p_res: dict[str, Any] = { + "id": to_base64("ProjectType", p.id), + "nextMilestonesProperty": [], + } + expected.append(p_res) + milestones = MilestoneFactory.create_batch(2, project=p) + milestones.sort(key=lambda ms: cast(datetime.datetime, ms.due_date)) + for m in milestones: + m_res: dict[str, Any] = { + "id": to_base64("MilestoneType", m.id), + "name": m.name, + "project": { + "id": p_res["id"], + "name": p.name, + }, + } + p_res["nextMilestonesProperty"].append(m_res) + + assert len(expected) == 2 + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3): + res = gql_client.query(query) + assert res.data == {"projectList": expected} + + @pytest.mark.django_db(transaction=True) def test_query_connection_with_resolver(db, gql_client: GraphQLTestClient): query = """