Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Improve handling of status codes. #573

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/drf_yasg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ Parameter Location
- :py:data:`~drf_yasg.openapi.IN_QUERY` is called :py:attr:`~drf_spectacular.utils.OpenApiParameter.QUERY`
- :py:data:`~drf_yasg.openapi.IN_HEADER` is called :py:attr:`~drf_spectacular.utils.OpenApiParameter.HEADER`
- :py:data:`~drf_yasg.openapi.IN_BODY` and :py:data:`~drf_yasg.openapi.IN_FORM` have no direct equivalent.
Instead you can use ``@extend_schema(request={"<media-type>": ...})`` or
``@extend_schema(request={("<status-code>", "<media-type"): ...})``.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We both missed that this didn't make sense for requests! 🤦🏻

Instead you can use ``@extend_schema(request={"<media-type>": ...})``.
- :py:attr:`~drf_spectacular.utils.OpenApiParameter.COOKIE` is also available.

Docstring Parsing
Expand Down
44 changes: 27 additions & 17 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rest_framework.schemas.inspectors import ViewInspector
from rest_framework.schemas.utils import get_pk_description # type: ignore
from rest_framework.settings import api_settings
from rest_framework.status import is_success
from rest_framework.utils.model_meta import get_field_info
from rest_framework.views import APIView

Expand Down Expand Up @@ -1010,7 +1011,7 @@ def get_examples(self):
""" override this for custom behaviour """
return []

def _get_examples(self, serializer, direction, media_type, status_code=None, extras=None):
def _get_examples(self, serializer, direction, media_type, status_code: typing.Optional[int] = None, extras=None):
""" Handles examples for request/response. purposefully ignores parameter examples """

# don't let the parameter examples influence the serializer example retrieval
Expand All @@ -1033,7 +1034,7 @@ def _get_examples(self, serializer, direction, media_type, status_code=None, ext
continue
if media_type and media_type != example.media_type:
continue
if status_code and status_code not in example.status_codes:
if status_code and status_code not in (example.status_codes or [200, 201]):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By postponing the default values here, we allow OpenApiExample to be unshackled such that it can be defined globally and reused for different status codes in multiple operations.

continue
filtered_examples.append(example)

Expand Down Expand Up @@ -1127,22 +1128,23 @@ def _get_response_bodies(self):
if self.method == 'DELETE':
return {'204': {'description': _('No response body')}}
if self._is_create_operation():
return {'201': self._get_response_for_code(response_serializers, '201')}
return {'200': self._get_response_for_code(response_serializers, '200')}
return {'201': self._get_response_for_code(response_serializers, 201)}
return {'200': self._get_response_for_code(response_serializers, 200)}
elif isinstance(response_serializers, dict):
# custom handling for overriding default return codes with @extend_schema
responses = {}
for code, serializer in response_serializers.items():
if isinstance(code, tuple):
code, media_types = str(code[0]), code[1:]
for status_code, serializer in response_serializers.items():
if isinstance(status_code, tuple):
status_code, *media_types = status_code
else:
code, media_types = str(code), None
content_response = self._get_response_for_code(serializer, code, media_types)
if code in responses:
responses[code]['content'].update(content_response['content'])
media_types = None
status_code = int(status_code)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Casting to int allows for http.HTTPStatus to work.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately it would be a regression

content_response = self._get_response_for_code(serializer, status_code, media_types)
if status_code in responses:
responses[status_code]['content'].update(content_response['content'])
else:
responses[code] = content_response
return responses
responses[status_code] = content_response
return {str(k): v for k, v in responses.items()}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be removed if we wanted to pass back int here too, but that would require updating all of the tests which currently have things like ...['responses']['200'] to have ...['responses'][200]. You could choose to do this - it's not difficult - but I didn't want to introduce too much noise for the purposes of this PR.

else:
warn(
f'could not resolve "{response_serializers}" for {self.method} {self.path}. '
Expand All @@ -1151,7 +1153,7 @@ def _get_response_bodies(self):
)
schema = build_basic_type(OpenApiTypes.OBJECT)
schema['description'] = _('Unspecified response body')
return {'200': self._get_response_for_code(schema, '200')}
return {'200': self._get_response_for_code(schema, 200)}

def _unwrap_list_serializer(self, serializer, direction) -> typing.Optional[dict]:
if is_field(serializer):
Expand All @@ -1170,6 +1172,14 @@ def _get_response_for_code(self, serializer, status_code, media_types=None):
serializer, description, examples = (
serializer.response, serializer.description, serializer.examples
)
for example in examples:
if example.status_codes is None:
example.status_codes = [status_code]
Comment on lines +1176 to +1177
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My main concern here is isolation as we're modifying the example. We might need to clone the object or otherwise pass through a special flag to _get_examples() to say that this is in an OpenApiResponse so we can assume that all is well if status_codes is None...

elif status_code not in example.status_codes:
warn(
f'example in response with status code {status_code} had'
f'status_codes set to {example.status_codes!r}'
)
Comment on lines +1178 to +1182
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't test this, but it struck me that if we're allowing status_codes to be None, or they can be pre-defined as the same value as the status code of the response, they could also have the wrong value.

Again, maybe this check should be pushed down into _get_examples() to be done when in this context of an OpenApiResponse to highlight that there is something wrong instead of silently ignoring it.

This check would be especially handy for the following (untested, hypothetical) situation:

unauthenticated_example = OpenApiExample(
    "UnauthenticatedExample",
    value={"error": "You are not authenticated."},
    status_codes=[401, 403],
)

@extend_schema(
    ...,
    responses={
        200: OpenApiResponse(description="...", examples=[unauthenticated_example]),  # Should raise warning...
        401: OpenApiResponse(description="...", examples=[unauthenticated_example]),
        403: OpenApiResponse(description="...", examples=[unauthenticated_example]),
)
@api_view(["POST"])
def view(request):
    ...

else:
description, examples = '', []

Expand Down Expand Up @@ -1207,7 +1217,7 @@ def _get_response_for_code(self, serializer, status_code, media_types=None):
if (
self._is_list_view(serializer)
and get_override(serializer, 'many') is not False
and ('200' <= status_code < '300' or spectacular_settings.ENABLE_LIST_MECHANICS_ON_NON_2XX)
Copy link
Contributor Author

@ngnpope ngnpope Oct 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, this worked, but it's sort of a fluke. 😄

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this actually works quite well given that 2XX is also a valid value. unless you go totally crazy there, this always does what it is supposed to do.

and (is_success(status_code) or spectacular_settings.ENABLE_LIST_MECHANICS_ON_NON_2XX)
):
schema = build_array_type(schema)
paginator = self._get_paginator()
Expand Down Expand Up @@ -1244,7 +1254,7 @@ def _get_response_for_code(self, serializer, status_code, media_types=None):
'description': description
}

def _get_response_headers_for_code(self, status_code) -> dict:
def _get_response_headers_for_code(self, status_code: int) -> dict:
result = {}
for parameter in self.get_override_parameters():
if not isinstance(parameter, OpenApiParameter):
Expand All @@ -1253,7 +1263,7 @@ def _get_response_headers_for_code(self, status_code) -> dict:
continue
if (
isinstance(parameter.response, list)
and status_code not in [str(code) for code in parameter.response]
and status_code not in parameter.response
):
continue

Expand Down
5 changes: 3 additions & 2 deletions drf_spectacular/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import sys
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

from rest_framework.fields import Field, empty
Expand Down Expand Up @@ -113,7 +114,7 @@ def __init__(
response_only: bool = False,
parameter_only: Optional[Tuple[str, _ParameterLocationType]] = None,
media_type: str = 'application/json',
status_codes: Optional[List[str]] = None,
status_codes: Optional[List[Union[HTTPStatus, int, str]]] = None,
):
self.name = name
self.summary = summary
Expand All @@ -124,7 +125,7 @@ def __init__(
self.response_only = response_only
self.parameter_only = parameter_only
self.media_type = media_type
self.status_codes = status_codes or ['200', '201']
self.status_codes = list(map(int, status_codes)) if status_codes else None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forcing list[int] | None early makes everything much easier downstream.



class OpenApiParameter(OpenApiSchemaBase):
Expand Down
47 changes: 44 additions & 3 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from http import HTTPStatus

from rest_framework import serializers, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response

from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import (
OpenApiExample, OpenApiParameter, extend_schema, extend_schema_serializer,
OpenApiExample, OpenApiParameter, OpenApiResponse, extend_schema, extend_schema_serializer,
)
from tests import assert_schema, generate_schema
from tests.models import SimpleModel
Expand Down Expand Up @@ -74,6 +76,8 @@ class ExampleTestWithExtendedViewSet(viewsets.GenericViewSet):
201: BSerializer,
400: OpenApiTypes.OBJECT,
403: OpenApiTypes.OBJECT,
404: OpenApiTypes.OBJECT,
500: OpenApiTypes.OBJECT,
},
examples=[
OpenApiExample(
Expand All @@ -94,7 +98,19 @@ class ExampleTestWithExtendedViewSet(viewsets.GenericViewSet):
'Create Error 403 Example',
value={'field': 'error'},
response_only=True,
status_codes=['403']
status_codes=['403'], # string
),
OpenApiExample(
'Create Error 404 Example',
value={'field': 'error'},
response_only=True,
status_codes=[404], # integer
),
OpenApiExample(
'Create Error 500 Example',
value={'field': 'error'},
response_only=True,
status_codes=[HTTPStatus.INTERNAL_SERVER_ERROR], # enum
),
],
)
Expand Down Expand Up @@ -143,7 +159,32 @@ def retrieve(self, request):
def raw_action(self, request):
return Response() # pragma: no cover

@extend_schema(responses=BSerializer)
@extend_schema(
responses={
200: OpenApiResponse(
description="",
response=BSerializer,
examples=[
OpenApiExample(
'Override 200 Example',
value={'field': 'ok'},
),
],
),
400: OpenApiResponse(
description="",
response=BSerializer,
examples=[
OpenApiExample(
'Override 400 Example',
value={'field': 'status_codes_undeclared_in_response_examples'},
# Ensure this works when status_codes are not provided.
# Previously this only worked automatically for 200 and 201.
),
],
),
},
)
@action(detail=False, methods=['POST'])
def override_extend_schema_action(self, request):
return Response() # pragma: no cover
Expand Down
40 changes: 40 additions & 0 deletions tests/test_examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,30 @@ paths:
field: error
summary: Create Error 403 Example
description: ''
'404':
content:
application/json:
schema:
type: object
additionalProperties: {}
examples:
CreateError404Example:
value:
field: error
summary: Create Error 404 Example
description: ''
'500':
content:
application/json:
schema:
type: object
additionalProperties: {}
examples:
CreateError500Example:
value:
field: error
summary: Create Error 500 Example
description: ''
/schema/{id}/:
get:
operationId: schema_retrieve
Expand Down Expand Up @@ -175,6 +199,22 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/B'
examples:
Override200Example:
value:
field: ok
summary: Override 200 Example
description: ''
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/B'
examples:
Override400Example:
value:
field: status_codes_undeclared_in_response_examples
summary: Override 400 Example
description: ''
/schema/raw_action/:
get:
Expand Down
40 changes: 39 additions & 1 deletion tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from decimal import Decimal
from functools import partialmethod
from http import HTTPStatus
from unittest import mock

import pytest
Expand All @@ -14,7 +15,7 @@
from django.urls import path, re_path, register_converter
from django.urls.converters import StringConverter
from rest_framework import (
filters, generics, mixins, pagination, parsers, renderers, routers, serializers, views,
filters, generics, mixins, pagination, parsers, renderers, routers, serializers, status, views,
viewsets,
)
from rest_framework.authentication import BasicAuthentication, TokenAuthentication
Expand Down Expand Up @@ -2753,3 +2754,40 @@ def view_func(request, format=None):
'wo': {'type': 'string', 'writeOnly': True, 'minLength': 1},
'rw': {'type': 'string', 'minLength': 1}
}


def test_response_status_codes_types(no_warnings):
class XSerializer(serializers.Serializer):
field = serializers.IntegerField()

@extend_schema(
request=XSerializer,
responses={
200: OpenApiResponse(
description='Integer status code.',
response=XSerializer,
),
'400': OpenApiResponse(
description='String status code.',
response=XSerializer,
),
status.HTTP_401_UNAUTHORIZED: OpenApiResponse(
description='DRF constant.',
response=XSerializer,
),
HTTPStatus.FORBIDDEN: OpenApiResponse(
description='http.HTTPStatus enum.',
response=XSerializer,
),
},
)
@api_view(['POST'])
def pi(request, format=None):
pass # pragma: no cover

schema = generate_schema('/x', view_function=pi)
operation = schema['paths']['/x']['post']
assert operation['responses']['200']['description'] == 'Integer status code.'
assert operation['responses']['400']['description'] == 'String status code.'
assert operation['responses']['401']['description'] == 'DRF constant.'
assert operation['responses']['403']['description'] == 'http.HTTPStatus enum.'