Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Allow http.HTTPMethod enum values in @action() decorator #512

Merged
merged 10 commits into from
Nov 19, 2023
63 changes: 27 additions & 36 deletions rest_framework-stubs/decorators.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Protocol, TypeVar

Expand All @@ -17,6 +18,31 @@ _View = TypeVar("_View", bound=Callable[..., HttpResponseBase])
_P = ParamSpec("_P")
_RESP = TypeVar("_RESP", bound=HttpResponseBase)

_MixedCaseHttpMethod: TypeAlias = Literal[
"GET",
"POST",
"DELETE",
"PUT",
"PATCH",
"TRACE",
"HEAD",
"OPTIONS",
"get",
"post",
"delete",
"put",
"patch",
"trace",
"head",
"options",
]
if sys.version_info >= (3, 11):
from http import HTTPMethod

_HttpMethod: TypeAlias = _MixedCaseHttpMethod | HTTPMethod
else:
_HttpMethod: TypeAlias = _MixedCaseHttpMethod

class MethodMapper(dict):
def __init__(self, action: _View, methods: Sequence[str]) -> None: ...
def _map(self, method: str, func: _View) -> _View: ...
Expand All @@ -29,43 +55,8 @@ class MethodMapper(dict):
def options(self, func: _View) -> _View: ...
def trace(self, func: _View) -> _View: ...

_LOWER_CASE_HTTP_VERBS: TypeAlias = Sequence[
Literal[
"get",
"post",
"delete",
"put",
"patch",
"trace",
"head",
"options",
]
]

_MIXED_CASE_HTTP_VERBS: TypeAlias = Sequence[
Literal[
"GET",
"POST",
"DELETE",
"PUT",
"PATCH",
"TRACE",
"HEAD",
"OPTIONS",
"get",
"post",
"delete",
"put",
"patch",
"trace",
"head",
"options",
]
]

class ViewSetAction(Protocol[_View]):
detail: bool
methods: _LOWER_CASE_HTTP_VERBS
url_path: str
url_name: str
kwargs: Mapping[str, Any]
Expand All @@ -84,7 +75,7 @@ def throttle_classes(throttle_classes: Sequence[BaseThrottle | type[BaseThrottle
def permission_classes(permission_classes: Sequence[_PermissionClass]) -> Callable[[_View], _View]: ...
def schema(view_inspector: ViewInspector | type[ViewInspector] | None) -> Callable[[_View], _View]: ...
def action(
methods: _MIXED_CASE_HTTP_VERBS | None = ...,
methods: Sequence[_HttpMethod] | None = ...,
detail: bool = ...,
url_path: str | None = ...,
url_name: str | None = ...,
Expand Down
20 changes: 18 additions & 2 deletions tests/typecheck/test_decorators.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,25 @@
from rest_framework.response import Response

class MyView(viewsets.ViewSet):

@action(methods=("get",), detail=False)
def view_func_1(self, request: Request) -> Response: ...
@action(methods=["post"], detail=False)
def view_func_2(self, request: Request) -> Response: ...
@action(methods=("GET",), detail=False)
def view_func_3(self, request: Request) -> Response: ...

- case: method_decorator_http_libary
skip: sys.version_info < (3, 11)
main: |
from http import HTTPMethod
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response

@action(methods=["post",], detail=False)
MY_VAR: HTTPMethod = HTTPMethod.POST
class MyView(viewsets.ViewSet):
@action(methods=[HTTPMethod.GET], detail=False)
def view_func_1(self, request: Request) -> Response: ...
@action(methods=[MY_VAR], detail=False)
def view_func_2(self, request: Request) -> Response: ...