Skip to content

Commit

Permalink
Fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
diesieben07 committed Mar 3, 2024
1 parent 0b2d41d commit b2799fc
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 15 deletions.
50 changes: 49 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ django = ">=3.2"
django-choices-field = { version = ">=2.2.2", optional = true }
django-debug-toolbar = { version = ">=3.4", optional = true }
strawberry-graphql = ">=0.212.0"
pyright = "^1.1.352"

[tool.poetry.group.dev.dependencies]
Markdown = "^3.3.7"
Expand Down
15 changes: 10 additions & 5 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Type,
TypeVar,
cast,
Optional,
)

from django.db import models
Expand Down Expand Up @@ -196,7 +197,8 @@ def get_custom_prefetches(self, info: GraphQLResolveInfo) -> list[Prefetch]:
assert_type(p, PrefetchCallable)
p = p(info) # noqa: PLW2901

if isinstance(p, Prefetch) and p.to_attr is not None:
# to_attr is not typed in django stubs
if isinstance(p, Prefetch) and p.to_attr is not None: # type: ignore
custom_prefetches.append(p)
return custom_prefetches

Expand Down Expand Up @@ -469,7 +471,7 @@ def _get_model_hints(
continue

# Add annotations from the field if they exist
field_store = cast(OptimizerStore | None, getattr(field, "store", None))
field_store = cast(Optional[OptimizerStore], getattr(field, "store", None))
custom_prefetches: list[Prefetch] = []
if field_store is not None:
if (
Expand Down Expand Up @@ -521,9 +523,10 @@ def _get_model_hints(
# Lastly, from the django field itself
if not model_fieldname:
model_fieldname = getattr(field, "django_name", None) or field.python_name
model_field = model_fields.get(model_fieldname, None)
model_field = model_fields.get(model_fieldname, None) if model_fieldname else None

if model_field is not None:
assert model_fieldname is not None # if we have a model_field, then model_fieldname must also be set
path = f"{prefix}{model_fieldname}"

if not custom_prefetches and isinstance(
Expand Down Expand Up @@ -636,15 +639,17 @@ def _get_model_hints(

if custom_prefetches:
for prefetch in custom_prefetches:
if prefetch.queryset is not None:
# stubs incorrectly say that queryset is never None
if prefetch.queryset is not None: # type: ignore
p_qs = prefetch.queryset
else:
p_qs = _get_prefetch_queryset(
remote_model, field, config, info
)
f_qs = f_store.apply(p_qs, info=info, config=config)
f_prefetch = Prefetch(
prefetch.prefetch_through, f_qs, prefetch.to_attr
# to_attr is not typed in django stubs
prefetch.prefetch_through, f_qs, prefetch.to_attr # type: ignore
)
if prefix:
f_prefetch.add_prefix(prefix)
Expand Down
12 changes: 8 additions & 4 deletions tests/projects/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Annotated, Any, Optional
from typing import TYPE_CHECKING, Any, List, Optional
from typing_extensions import Annotated

import strawberry
from django.contrib.auth import get_user_model
Expand All @@ -12,6 +13,7 @@

if TYPE_CHECKING:
from django.db.models.manager import RelatedManager
from .schema import MilestoneType

User = get_user_model()

Expand Down Expand Up @@ -51,6 +53,8 @@ class Status(models.TextChoices):
default=None,
)

next_milestones_prop_pf: List["Milestone"]

@model_property(annotate={"_milestone_count": Count("milestone")})
def is_small(self) -> bool:
return self._milestone_count < 3 # type: ignore
Expand All @@ -66,12 +70,12 @@ def is_small(self) -> bool:
)
def next_milestones_property(
self,
) -> list[Annotated["MilestoneType", strawberry.lazy(".schema")]]:
) -> List[Annotated["MilestoneType", strawberry.lazy(".schema")]]:
"""The milestones for the project ordered by their due date
"""
if hasattr(self, "next_milestones_prop_pf"):
return self.next_milestones_prop_pf
return self.milestones.filter(due_date__isnull=False).order_by("due_date")
return self.next_milestones_prop_pf # type: ignore
return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore


class Milestone(models.Model):
Expand Down
7 changes: 4 additions & 3 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class ProjectType(relay.Node):
is_small: strawberry.auto

next_milestones_property: strawberry.auto
next_milestones_pf: List[Milestone]

@strawberry_django.field(
prefetch_related=lambda _: Prefetch(
Expand All @@ -114,12 +115,12 @@ class ProjectType(relay.Node):
),
)
)
def next_milestones(self) -> "list[MilestoneType]":
def next_milestones(self) -> "List[MilestoneType]":
"""The milestones for the project ordered by their due date
"""
if hasattr(self, "next_milestones_pf"):
return self.next_milestones_pf
return self.milestones.filter(due_date__isnull=False).order_by("due_date")
return self.next_milestones_pf # type: ignore
return self.milestones.filter(due_date__isnull=False).order_by("due_date") # type: ignore


@strawberry_django.filter(Milestone, lookups=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def test_query_prefetch_with_to_attr(db, gql_client: GraphQLTestClient):
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)
milestones.sort(key=lambda ms: cast(datetime.datetime, ms.due_date))
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
Expand Down Expand Up @@ -605,7 +605,7 @@ def test_query_prefetch_with_to_attr_model_property(db, gql_client: GraphQLTestC
}
expected.append(p_res)
milestones = MilestoneFactory.create_batch(2, project=p)
milestones.sort(key=lambda ms: ms.due_date)
milestones.sort(key=lambda ms: cast(datetime.datetime, ms.due_date))
for m in milestones:
m_res: dict[str, Any] = {
"id": to_base64("MilestoneType", m.id),
Expand Down

0 comments on commit b2799fc

Please # to comment.