Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Deduce model field names from custom prefetches #473

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
109 changes: 85 additions & 24 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@
),
)

def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]:
custom_prefetches = []
for p in self.prefetch_related:
if isinstance(p, Callable):
assert_type(p, PrefetchCallable)
p = p(info) # noqa: PLW2901
bellini666 marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(p, Prefetch) and p.to_attr is not None:

Check failure on line 199 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Cannot access member "to_attr" for type "Prefetch[Any]"   Member "to_attr" is unknown (reportAttributeAccessIssue)
diesieben07 marked this conversation as resolved.
Show resolved Hide resolved
custom_prefetches.append(p)
return custom_prefetches

def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo):
prefetch_related = []
for p in self.prefetch_related:
Expand Down Expand Up @@ -458,7 +469,8 @@
continue

# Add annotations from the field if they exist
field_store = getattr(field, "store", None)
field_store = cast(OptimizerStore | None, getattr(field, "store", None))

Check failure on line 472 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Alternative syntax for unions requires Python 3.10 or newer (reportGeneralTypeIssues)
diesieben07 marked this conversation as resolved.
Show resolved Hide resolved
custom_prefetches: list[Prefetch] = []
if field_store is not None:
if (
len(field_store.annotate) == 1
Expand All @@ -477,20 +489,47 @@
store |= (
field_store.with_prefix(prefix, info=info) if prefix else field_store
)
custom_prefetches = field_store.get_custom_prefetches(info)

# Then from the model property if one is defined
model_attr = getattr(model, field.python_name, None)
if model_attr is not None and isinstance(model_attr, ModelProperty):
attr_store = model_attr.store
store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store
attr_store_prefetches = attr_store.get_custom_prefetches(info)
if attr_store_prefetches:
custom_prefetches.extend(attr_store_prefetches)

model_fieldname: str | None = None
model_field = None
# try to find the model field name in any custom prefetches
if custom_prefetches:
for prefetch in custom_prefetches:
prefetch_field = model_fields.get(prefetch.prefetch_through, None)
if prefetch_field:
if not model_field:
model_field = prefetch_field
model_fieldname = prefetch.prefetch_through
elif model_field != prefetch_field:
# we found more than one model field from the custom prefetches
# not much we can do here
model_field = None
model_fieldname = None
custom_prefetches = []
break

# Lastly, from the django field itself
model_fieldname: str = getattr(field, "django_name", None) or field.python_name
model_field = model_fields.get(model_fieldname, None)
if not model_fieldname:
model_fieldname = getattr(field, "django_name", None) or field.python_name
model_field = model_fields.get(model_fieldname, None)

Check failure on line 524 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

No overloads for "get" match the provided arguments (reportCallIssue)

Check failure on line 524 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Argument of type "str | None" cannot be assigned to parameter "__key" of type "str" in function "get"   Type "str | None" cannot be assigned to type "str"     "None" is incompatible with "str" (reportArgumentType)

if model_field is not None:
path = f"{prefix}{model_fieldname}"

if isinstance(model_field, (models.ForeignKey, OneToOneRel)):
if not custom_prefetches and isinstance(
model_field, (models.ForeignKey, OneToOneRel)
):
# only select_related if there is no custom prefetch
store.only.append(path)
store.select_related.append(path)

Expand All @@ -517,11 +556,15 @@
if f_store is not None:
cache.setdefault(f_model, []).append((level, f_store))
store |= f_store.with_prefix(path, info=info)
elif GenericForeignKey and isinstance(model_field, GenericForeignKey):
elif (
not custom_prefetches
and GenericForeignKey
and isinstance(model_field, GenericForeignKey)
):
# There's not much we can do to optimize generic foreign keys regarding
# only/select_related because they can be anything.
# Just prefetch_related them
store.prefetch_related.append(model_fieldname)

Check failure on line 567 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Argument of type "str | None" cannot be assigned to parameter "__object" of type "PrefetchType" in function "append" (reportArgumentType)
elif isinstance(
model_field,
_relation_fields,
Expand All @@ -530,7 +573,8 @@
if len(f_types) > 1:
# This might be a generic foreign key.
# In this case, just prefetch it
store.prefetch_related.append(model_fieldname)
if not custom_prefetches:
store.prefetch_related.append(model_fieldname)

Check failure on line 577 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Argument of type "str | None" cannot be assigned to parameter "__object" of type "PrefetchType" in function "append" (reportArgumentType)
elif len(f_types) == 1:
remote_field = model_field.remote_field
remote_model = remote_field.model
Expand Down Expand Up @@ -590,24 +634,41 @@

cache.setdefault(remote_model, []).append((level, f_store))

# If prefetch_custom_queryset is false, use _base_manager here
# instead of _default_manager because we are getting related
# objects, and not querying it directly. Else use the type's
# get_queryset and model's custom QuerySet.
base_qs = _get_prefetch_queryset(
remote_model,
field,
config,
info,
)
f_qs = f_store.apply(
base_qs,
info=info,
config=config,
)
f_prefetch = Prefetch(path, queryset=f_qs)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
if custom_prefetches:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (llm): This block introduces a significant change in how prefetches are handled based on the presence of custom_prefetches. It's a complex addition that could benefit from a bit more inline documentation to explain the rationale behind this approach, especially for future maintainers.

for prefetch in custom_prefetches:
if prefetch.queryset is not None:

Check warning on line 639 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Condition will always evaluate to True since the types "QuerySet[Unknown]" and "None" have no overlap (reportUnnecessaryComparison)
p_qs = prefetch.queryset
else:
p_qs = _get_prefetch_queryset(
remote_model, field, config, info
)
f_qs = f_store.apply(p_qs, info=info, config=config)
f_prefetch = Prefetch(
prefetch.prefetch_through, f_qs, prefetch.to_attr

Check failure on line 647 in strawberry_django/optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Cannot access member "to_attr" for type "Prefetch[Unknown]"   Member "to_attr" is unknown (reportAttributeAccessIssue)
)
if prefix:
f_prefetch.add_prefix(prefix)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
else:
# If prefetch_custom_queryset is false, use _base_manager here
# instead of _default_manager because we are getting related
# objects, and not querying it directly. Else use the type's
# get_queryset and model's custom QuerySet.
base_qs = _get_prefetch_queryset(
remote_model,
field,
config,
info,
)
f_qs = f_store.apply(
base_qs,
info=info,
config=config,
)
f_prefetch = Prefetch(path, queryset=f_qs)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
else:
store.only.append(path)

Expand Down
23 changes: 21 additions & 2 deletions tests/projects/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Annotated, Any, Optional

Check failure on line 1 in tests/projects/models.py

View workflow job for this annotation

GitHub Actions / Typing

"Annotated" is unknown import symbol (reportAttributeAccessIssue)

import strawberry
from django.contrib.auth import get_user_model
from django.db import models
from django.db.models import Count, QuerySet
from django.db.models import Count, Prefetch, QuerySet
from django.utils.translation import gettext_lazy as _
from django_choices_field import TextChoicesField

Expand Down Expand Up @@ -54,6 +55,24 @@
def is_small(self) -> bool:
return self._milestone_count < 3 # type: ignore

@model_property(
prefetch_related=lambda _: Prefetch(
"milestones",
to_attr="next_milestones_prop_pf",
queryset=Milestone.objects.filter(due_date__isnull=False).order_by(
"due_date"
),
)
)
def next_milestones_property(
self,
) -> list[Annotated["MilestoneType", strawberry.lazy(".schema")]]:
"""The milestones for the project ordered by their due date
"""
if hasattr(self, "next_milestones_prop_pf"):
return self.next_milestones_prop_pf
return self.milestones.filter(due_date__isnull=False).order_by("due_date")


class Milestone(models.Model):
issues: "RelatedManager[Issue]"
Expand Down
30 changes: 29 additions & 1 deletion tests/projects/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import datetime
import decimal
from typing import Iterable, List, Optional, Type, cast
from typing import Iterable, List, Optional, Type, Union, cast

import strawberry
from django.contrib.auth import get_user_model
Expand Down Expand Up @@ -103,6 +103,24 @@ class ProjectType(relay.Node):
cost: strawberry.auto = strawberry_django.field(extensions=[IsAuthenticated()])
is_small: strawberry.auto

next_milestones_property: strawberry.auto

@strawberry_django.field(
prefetch_related=lambda _: Prefetch(
"milestones",
to_attr="next_milestones_pf",
queryset=Milestone.objects.filter(due_date__isnull=False).order_by(
"due_date"
),
)
)
def next_milestones(self) -> "list[MilestoneType]":
"""The milestones for the project ordered by their due date
"""
if hasattr(self, "next_milestones_pf"):
return self.next_milestones_pf
return self.milestones.filter(due_date__isnull=False).order_by("due_date")


@strawberry_django.filter(Milestone, lookups=True)
class MilestoneFilter:
Expand Down Expand Up @@ -296,6 +314,16 @@ class ProjectConnection(ListConnectionWithTotalCount[ProjectType]):
"""Project connection documentation."""


ProjectFeedItem = Annotated[
Union[IssueType, MilestoneType], strawberry.union("ProjectFeedItem")
]


@strawberry.type
class ProjectFeedConnection(relay.Connection[ProjectFeedItem]):
pass
Comment on lines +316 to +323
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a left over from a test attempt? =P

This is not being used, so it should be removed



@strawberry.type
class Query:
"""All available queries for this schema."""
Expand Down
90 changes: 90 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,96 @@
assert res.data == {"project": e}


@pytest.mark.django_db(transaction=True)
def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient):
query = """
query TestQuery {
projectList {
id
nextMilestones {
id
name
project {
id
name
}
}
}
}
"""

expected = []
for p in ProjectFactory.create_batch(2):
p_res: dict[str, Any] = {
"id": to_base64("ProjectType", p.id),
"nextMilestones": [],
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)

Check failure on line 563 in tests/test_optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Argument of type "(ms: Milestone) -> (date | None)" cannot be assigned to parameter "key" of type "(Milestone) -> SupportsRichComparison" in function "sort"   Type "(ms: Milestone) -> (date | None)" cannot be assigned to type "(Milestone) -> SupportsRichComparison"     Function return type "date | None" is incompatible with type "SupportsRichComparison" (reportArgumentType)
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
"name": m.name,
"project": {
"id": p_res["id"],
"name": p.name,
},
}
p_res["nextMilestones"].append(m_res)

assert len(expected) == 2

with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3):
res = gql_client.query(query)
assert res.data == {"projectList": expected}


@pytest.mark.django_db(transaction=True)
def test_query_prefetch_with_to_attr_model_property(db, gql_client: GraphQLTestClient):
query = """
query TestQuery {
projectList {
id
nextMilestonesProperty {
id
name
project {
id
name
}
}
}
}
"""

expected = []
for p in ProjectFactory.create_batch(2):
p_res: dict[str, Any] = {
"id": to_base64("ProjectType", p.id),
"nextMilestonesProperty": [],
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)

Check failure on line 608 in tests/test_optimizer.py

View workflow job for this annotation

GitHub Actions / Typing

Argument of type "(ms: Milestone) -> (date | None)" cannot be assigned to parameter "key" of type "(Milestone) -> SupportsRichComparison" in function "sort"   Type "(ms: Milestone) -> (date | None)" cannot be assigned to type "(Milestone) -> SupportsRichComparison"     Function return type "date | None" is incompatible with type "SupportsRichComparison" (reportArgumentType)
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
"name": m.name,
"project": {
"id": p_res["id"],
"name": p.name,
},
}
p_res["nextMilestonesProperty"].append(m_res)

assert len(expected) == 2

with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3):
res = gql_client.query(query)
assert res.data == {"projectList": expected}


@pytest.mark.django_db(transaction=True)
def test_query_connection_with_resolver(db, gql_client: GraphQLTestClient):
query = """
Expand Down
Loading