Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Deduce model field names from custom prefetches #473

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
99 changes: 75 additions & 24 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@ def with_hints(
),
)

def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]:
custom_prefetches = []
for p in self.prefetch_related:
if isinstance(p, Callable):
assert_type(p, PrefetchCallable)
p = p(info) # noqa: PLW2901
bellini666 marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(p, Prefetch) and p.to_attr is not None:
diesieben07 marked this conversation as resolved.
Show resolved Hide resolved
custom_prefetches.append(p)
return custom_prefetches

def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo):
prefetch_related = []
for p in self.prefetch_related:
Expand Down Expand Up @@ -458,7 +469,8 @@ def _get_model_hints(
continue

# Add annotations from the field if they exist
field_store = getattr(field, "store", None)
field_store = cast(OptimizerStore | None, getattr(field, "store", None))
diesieben07 marked this conversation as resolved.
Show resolved Hide resolved
custom_prefetches: list[Prefetch] = []
if field_store is not None:
if (
len(field_store.annotate) == 1
Expand All @@ -477,20 +489,45 @@ def _get_model_hints(
store |= (
field_store.with_prefix(prefix, info=info) if prefix else field_store
)
custom_prefetches = field_store.get_custom_prefetches(info)

# Then from the model property if one is defined
model_attr = getattr(model, field.python_name, None)
if model_attr is not None and isinstance(model_attr, ModelProperty):
attr_store = model_attr.store
store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store
attr_store_prefetches = attr_store.get_custom_prefetches(info)
if attr_store_prefetches:
custom_prefetches.extend(attr_store_prefetches)

model_fieldname: str | None = None
model_field = None
# try to find the model field name in any custom prefetches
if custom_prefetches:
for prefetch in custom_prefetches:
prefetch_field = model_fields.get(prefetch.prefetch_through, None)
if prefetch_field:
if not model_field:
model_field = prefetch_field
model_fieldname = prefetch.prefetch_through
elif model_field != prefetch_field:
# we found more than one model field from the custom prefetches
# not much we can do here
model_field = None
model_fieldname = None
custom_prefetches = []
break

# Lastly, from the django field itself
model_fieldname: str = getattr(field, "django_name", None) or field.python_name
model_field = model_fields.get(model_fieldname, None)
if not model_fieldname:
model_fieldname = getattr(field, "django_name", None) or field.python_name
model_field = model_fields.get(model_fieldname, None)

if model_field is not None:
path = f"{prefix}{model_fieldname}"

if isinstance(model_field, (models.ForeignKey, OneToOneRel)):
if not custom_prefetches and isinstance(model_field, (models.ForeignKey, OneToOneRel)):
# only select_related if there is no custom prefetch
store.only.append(path)
store.select_related.append(path)

Expand All @@ -517,7 +554,7 @@ def _get_model_hints(
if f_store is not None:
cache.setdefault(f_model, []).append((level, f_store))
store |= f_store.with_prefix(path, info=info)
elif GenericForeignKey and isinstance(model_field, GenericForeignKey):
elif not custom_prefetches and GenericForeignKey and isinstance(model_field, GenericForeignKey):
# There's not much we can do to optimize generic foreign keys regarding
# only/select_related because they can be anything.
# Just prefetch_related them
Expand All @@ -530,7 +567,8 @@ def _get_model_hints(
if len(f_types) > 1:
# This might be a generic foreign key.
# In this case, just prefetch it
store.prefetch_related.append(model_fieldname)
if not custom_prefetches:
store.prefetch_related.append(model_fieldname)
elif len(f_types) == 1:
remote_field = model_field.remote_field
remote_model = remote_field.model
Expand Down Expand Up @@ -590,24 +628,37 @@ def _get_model_hints(

cache.setdefault(remote_model, []).append((level, f_store))

# If prefetch_custom_queryset is false, use _base_manager here
# instead of _default_manager because we are getting related
# objects, and not querying it directly. Else use the type's
# get_queryset and model's custom QuerySet.
base_qs = _get_prefetch_queryset(
remote_model,
field,
config,
info,
)
f_qs = f_store.apply(
base_qs,
info=info,
config=config,
)
f_prefetch = Prefetch(path, queryset=f_qs)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
if custom_prefetches:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (llm): This block introduces a significant change in how prefetches are handled based on the presence of custom_prefetches. It's a complex addition that could benefit from a bit more inline documentation to explain the rationale behind this approach, especially for future maintainers.

for prefetch in custom_prefetches:
if prefetch.queryset is not None:
p_qs = prefetch.queryset
else:
p_qs = _get_prefetch_queryset(remote_model, field, config, info)
f_qs = f_store.apply(p_qs, info=info, config=config)
f_prefetch = Prefetch(prefetch.prefetch_through, f_qs, prefetch.to_attr)
if prefix:
f_prefetch.add_prefix(prefix)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
else:
# If prefetch_custom_queryset is false, use _base_manager here
# instead of _default_manager because we are getting related
# objects, and not querying it directly. Else use the type's
# get_queryset and model's custom QuerySet.
base_qs = _get_prefetch_queryset(
remote_model,
field,
config,
info,
)
f_qs = f_store.apply(
base_qs,
info=info,
config=config,
)
f_prefetch = Prefetch(path, queryset=f_qs)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
else:
store.only.append(path)

Expand Down
21 changes: 19 additions & 2 deletions tests/projects/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Annotated

import strawberry
from django.contrib.auth import get_user_model
from django.db import models
from django.db.models import Count, QuerySet
from django.db.models import Count, QuerySet, Prefetch
from django.utils.translation import gettext_lazy as _
from django_choices_field import TextChoicesField

Expand Down Expand Up @@ -54,6 +55,22 @@ class Status(models.TextChoices):
def is_small(self) -> bool:
return self._milestone_count < 3 # type: ignore

@model_property(
prefetch_related=lambda _: Prefetch(
"milestones",
to_attr="next_milestones_prop_pf",
queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date")
)
)
def next_milestones_property(self) -> list[Annotated['MilestoneType', strawberry.lazy('.schema')]]:
"""
The milestones for the project ordered by their due date
"""
if hasattr(self, 'next_milestones_prop_pf'):
return self.next_milestones_prop_pf
else:
return self.milestones.filter(due_date__isnull=False).order_by("due_date")


class Milestone(models.Model):
issues: "RelatedManager[Issue]"
Expand Down
28 changes: 27 additions & 1 deletion tests/projects/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import datetime
import decimal
from typing import Iterable, List, Optional, Type, cast
from typing import Iterable, List, Optional, Type, cast, Union

import strawberry
from django.contrib.auth import get_user_model
Expand Down Expand Up @@ -103,6 +103,24 @@ class ProjectType(relay.Node):
cost: strawberry.auto = strawberry_django.field(extensions=[IsAuthenticated()])
is_small: strawberry.auto

next_milestones_property: strawberry.auto

@strawberry_django.field(
prefetch_related=lambda _: Prefetch(
"milestones",
to_attr="next_milestones_pf",
queryset=Milestone.objects.filter(due_date__isnull=False).order_by("due_date")
)
)
def next_milestones(self) -> "list[MilestoneType]":
"""
The milestones for the project ordered by their due date
"""
if hasattr(self, 'next_milestones_pf'):
return self.next_milestones_pf
else:
return self.milestones.filter(due_date__isnull=False).order_by("due_date")


@strawberry_django.filter(Milestone, lookups=True)
class MilestoneFilter:
Expand Down Expand Up @@ -296,6 +314,14 @@ class ProjectConnection(ListConnectionWithTotalCount[ProjectType]):
"""Project connection documentation."""


ProjectFeedItem = Annotated[Union[IssueType, MilestoneType], strawberry.union('ProjectFeedItem')]


@strawberry.type
class ProjectFeedConnection(relay.Connection[ProjectFeedItem]):
pass


@strawberry.type
class Query:
"""All available queries for this schema."""
Expand Down
90 changes: 90 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,96 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
assert res.data == {"project": e}


@pytest.mark.django_db(transaction=True)
def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient):
query = """
query TestQuery {
projectList {
id
nextMilestones {
id
name
project {
id
name
}
}
}
}
"""

expected = []
for p in ProjectFactory.create_batch(2):
p_res: dict[str, Any] = {
"id": to_base64("ProjectType", p.id),
"nextMilestones": [],
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
"name": m.name,
"project": {
"id": p_res["id"],
"name": p.name,
},
}
p_res["nextMilestones"].append(m_res)

assert len(expected) == 2

with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3):
res = gql_client.query(query)
assert res.data == {"projectList": expected}


@pytest.mark.django_db(transaction=True)
def test_query_prefetch_with_to_attr_model_property(db, gql_client: GraphQLTestClient):
query = """
query TestQuery {
projectList {
id
nextMilestonesProperty {
id
name
project {
id
name
}
}
}
}
"""

expected = []
for p in ProjectFactory.create_batch(2):
p_res: dict[str, Any] = {
"id": to_base64("ProjectType", p.id),
"nextMilestonesProperty": [],
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
"name": m.name,
"project": {
"id": p_res["id"],
"name": p.name,
},
}
p_res["nextMilestonesProperty"].append(m_res)

assert len(expected) == 2

with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3):
res = gql_client.query(query)
assert res.data == {"projectList": expected}


@pytest.mark.django_db(transaction=True)
def test_query_connection_with_resolver(db, gql_client: GraphQLTestClient):
query = """
Expand Down