diff --git a/rest_framework-stubs/decorators.pyi b/rest_framework-stubs/decorators.pyi index cd177939d..8c1b07570 100644 --- a/rest_framework-stubs/decorators.pyi +++ b/rest_framework-stubs/decorators.pyi @@ -1,3 +1,4 @@ +import sys from collections.abc import Callable, Mapping, Sequence from typing import Any, Literal, Protocol, TypeVar @@ -17,6 +18,31 @@ _View = TypeVar("_View", bound=Callable[..., HttpResponseBase]) _P = ParamSpec("_P") _RESP = TypeVar("_RESP", bound=HttpResponseBase) +_MixedCaseHttpMethod: TypeAlias = Literal[ + "GET", + "POST", + "DELETE", + "PUT", + "PATCH", + "TRACE", + "HEAD", + "OPTIONS", + "get", + "post", + "delete", + "put", + "patch", + "trace", + "head", + "options", +] +if sys.version_info >= (3, 11): + from http import HTTPMethod + + _HttpMethod: TypeAlias = _MixedCaseHttpMethod | HTTPMethod +else: + _HttpMethod: TypeAlias = _MixedCaseHttpMethod + class MethodMapper(dict): def __init__(self, action: _View, methods: Sequence[str]) -> None: ... def _map(self, method: str, func: _View) -> _View: ... @@ -29,43 +55,8 @@ class MethodMapper(dict): def options(self, func: _View) -> _View: ... def trace(self, func: _View) -> _View: ... -_LOWER_CASE_HTTP_VERBS: TypeAlias = Sequence[ - Literal[ - "get", - "post", - "delete", - "put", - "patch", - "trace", - "head", - "options", - ] -] - -_MIXED_CASE_HTTP_VERBS: TypeAlias = Sequence[ - Literal[ - "GET", - "POST", - "DELETE", - "PUT", - "PATCH", - "TRACE", - "HEAD", - "OPTIONS", - "get", - "post", - "delete", - "put", - "patch", - "trace", - "head", - "options", - ] -] - class ViewSetAction(Protocol[_View]): detail: bool - methods: _LOWER_CASE_HTTP_VERBS url_path: str url_name: str kwargs: Mapping[str, Any] @@ -84,7 +75,7 @@ def throttle_classes(throttle_classes: Sequence[BaseThrottle | type[BaseThrottle def permission_classes(permission_classes: Sequence[_PermissionClass]) -> Callable[[_View], _View]: ... def schema(view_inspector: ViewInspector | type[ViewInspector] | None) -> Callable[[_View], _View]: ... def action( - methods: _MIXED_CASE_HTTP_VERBS | None = ..., + methods: Sequence[_HttpMethod] | None = ..., detail: bool = ..., url_path: str | None = ..., url_name: str | None = ..., diff --git a/tests/typecheck/test_decorators.yml b/tests/typecheck/test_decorators.yml index 05c58d145..666faf799 100644 --- a/tests/typecheck/test_decorators.yml +++ b/tests/typecheck/test_decorators.yml @@ -55,9 +55,25 @@ from rest_framework.response import Response class MyView(viewsets.ViewSet): - @action(methods=("get",), detail=False) def view_func_1(self, request: Request) -> Response: ... + @action(methods=["post"], detail=False) + def view_func_2(self, request: Request) -> Response: ... + @action(methods=("GET",), detail=False) + def view_func_3(self, request: Request) -> Response: ... + +- case: method_decorator_http_libary + skip: sys.version_info < (3, 11) + main: | + from http import HTTPMethod + from rest_framework import viewsets + from rest_framework.decorators import action + from rest_framework.request import Request + from rest_framework.response import Response - @action(methods=["post",], detail=False) + MY_VAR: HTTPMethod = HTTPMethod.POST + class MyView(viewsets.ViewSet): + @action(methods=[HTTPMethod.GET], detail=False) + def view_func_1(self, request: Request) -> Response: ... + @action(methods=[MY_VAR], detail=False) def view_func_2(self, request: Request) -> Response: ...