diff --git a/docs/Usage/Response.md b/docs/Usage/Response.md index 8d8300b9..023e9fa9 100644 --- a/docs/Usage/Response.md +++ b/docs/Usage/Response.md @@ -53,9 +53,46 @@ def hello(path: HelloPath): return response ``` - ![image-20210526104627124](../assets/image-20210526104627124.png) +## Validate responses + +By default, responses are not validated. If you need to validate responses, set validate_responses to True. Here are +several ways to achieve this: + +```python +# 1. APP level +app = OpenAPI(__name__, validate_response=True) + +# 2. APIBlueprint level +api = APIBlueprint(__name__, validate_response=True) + +# 3. APIView level +@api_view.route("/test") +class TestAPI: + @api_view.doc(responses={201: Response}, validate_response=True) + def post(self): + ... + +# 4. api level +@app.post("/test", responses={201: Response}, validate_response=True) +def endpoint_test(body: BaseRequest): + ... +``` + +You can also customize the default behavior of response validation by using a custom `validate_response_callback`. + +```python + +def validate_response_callback(response: Any, responses: Optional[ResponseDict] = None) -> Any: + + # do something + + return response + +app = OpenAPI(__name__, validate_response=True, validate_response_callback=validate_response_callback) +``` + ## More information about OpenAPI responses - [OpenAPI Responses Object](https://spec.openapis.org/oas/v3.1.0#responses-object), it includes the Response Object. diff --git a/flask_openapi3/blueprint.py b/flask_openapi3/blueprint.py index b0b85c39..c49735b7 100644 --- a/flask_openapi3/blueprint.py +++ b/flask_openapi3/blueprint.py @@ -33,6 +33,7 @@ def __init__( abp_responses: Optional[ResponseDict] = None, doc_ui: bool = True, operation_id_callback: Callable = get_operation_id_for_path, + validate_response: Optional[bool] = None, **kwargs: Any ) -> None: """ @@ -49,6 +50,7 @@ def __init__( operation_id_callback: Callback function for custom operation_id generation. Receives name (str), path (str) and method (str) parameters. Defaults to `get_operation_id_for_path` from utils + validate_response: Verify the response body. **kwargs: Flask Blueprint kwargs """ super(APIBlueprint, self).__init__(name, import_name, **kwargs) @@ -71,6 +73,9 @@ def __init__( # Set the operation ID callback function self.operation_id_callback: Callable = operation_id_callback + # Verify the response body + self.validate_response = validate_response + def register_api(self, api: "APIBlueprint") -> None: """Register a nested APIBlueprint""" diff --git a/flask_openapi3/models/components.py b/flask_openapi3/models/components.py index 1bdbeb15..c498b52e 100644 --- a/flask_openapi3/models/components.py +++ b/flask_openapi3/models/components.py @@ -3,7 +3,7 @@ # @Time : 2023/7/4 9:36 from typing import Optional, Union, Any -from pydantic import BaseModel, Field +from pydantic import BaseModel from .callback import Callback from .example import Example @@ -23,7 +23,7 @@ class Components(BaseModel): https://spec.openapis.org/oas/v3.1.0#components-object """ - schemas: Optional[dict[str, Union[Reference, Schema]]] = Field(None) + schemas: Optional[dict[str, Union[Reference, Schema]]] = None responses: Optional[dict[str, Union[Response, Reference]]] = None parameters: Optional[dict[str, Union[Parameter, Reference]]] = None examples: Optional[dict[str, Union[Example, Reference]]] = None diff --git a/flask_openapi3/openapi.py b/flask_openapi3/openapi.py index 5c1c91e8..91b1a076 100644 --- a/flask_openapi3/openapi.py +++ b/flask_openapi3/openapi.py @@ -42,6 +42,7 @@ from .utils import parse_and_store_tags from .utils import parse_method from .utils import parse_parameters +from .utils import run_validate_response from .view import APIView @@ -63,6 +64,8 @@ def __init__( doc_ui: bool = True, doc_prefix: str = "/openapi", doc_url: str = "/openapi.json", + validate_response: Optional[bool] = None, + validate_response_callback: Callable = run_validate_response, **kwargs: Any ) -> None: """ @@ -95,6 +98,8 @@ def __init__( Defaults to "/openapi". doc_url: URL for accessing the OpenAPI specification document in JSON format. Defaults to "/openapi.json". + validate_response: Verify the response body. + validate_response_callback: Validation and return response. **kwargs: Additional kwargs to be passed to Flask. """ super(OpenAPI, self).__init__(import_name, **kwargs) @@ -144,6 +149,10 @@ def __init__( # Add the OpenAPI command self.cli.add_command(openapi_command) # type: ignore + # Verify the response body + self.validate_response = validate_response + self.validate_response_callback = validate_response_callback + # Initialize specification JSON self.spec_json: dict = {} self.spec = APISpec( diff --git a/flask_openapi3/scaffold.py b/flask_openapi3/scaffold.py index 0e38e5ac..728c8192 100644 --- a/flask_openapi3/scaffold.py +++ b/flask_openapi3/scaffold.py @@ -5,6 +5,7 @@ from functools import wraps from typing import Callable, Optional, Any +from flask import current_app from flask.wrappers import Response as FlaskResponse from .models import ExternalDocumentation @@ -35,10 +36,10 @@ def _collect_openapi_info( doc_ui: bool = True, method: str = HTTPMethod.GET ) -> ParametersTuple: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def register_api(self, api) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def _add_url_rule( self, @@ -48,10 +49,11 @@ def _add_url_rule( provide_automatic_options=None, **options, ) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover @staticmethod def create_view_func( + # self, func, header, cookie, @@ -61,8 +63,11 @@ def create_view_func( body, raw, view_class=None, - view_kwargs=None + view_kwargs=None, + responses: Optional[ResponseDict] = None, + validate_response: Optional[bool] = None, ): + is_coroutine_function = inspect.iscoroutinefunction(func) if is_coroutine_function: @wraps(func) @@ -89,6 +94,19 @@ async def view_func(**kwargs) -> FlaskResponse: response = await func(view_object, **func_kwargs) else: response = await func(**func_kwargs) + + if hasattr(current_app, "validate_response"): + if validate_response is None: + _validate_response = current_app.validate_response + else: + _validate_response = validate_response + else: + _validate_response = validate_response + + if _validate_response and responses: + validate_response_callback = getattr(current_app, "validate_response_callback") + return validate_response_callback(response, responses) + return response else: @wraps(func) @@ -115,6 +133,19 @@ def view_func(**kwargs) -> FlaskResponse: response = func(view_object, **func_kwargs) else: response = func(**func_kwargs) + + if hasattr(current_app, "validate_response"): + if validate_response is None: + _validate_response = current_app.validate_response + else: + _validate_response = validate_response + else: + _validate_response = validate_response + + if _validate_response and responses: + validate_response_callback = getattr(current_app, "validate_response_callback") + return validate_response_callback(response, responses) + return response if not hasattr(func, "view"): @@ -137,6 +168,7 @@ def get( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -156,6 +188,7 @@ def get( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -177,7 +210,20 @@ def decorator(func) -> Callable: method=HTTPMethod.GET ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) + options.update({"methods": [HTTPMethod.GET]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -200,6 +246,7 @@ def post( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -219,6 +266,7 @@ def post( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -240,7 +288,20 @@ def decorator(func) -> Callable: method=HTTPMethod.POST ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) + options.update({"methods": [HTTPMethod.POST]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -263,6 +324,7 @@ def put( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -282,6 +344,7 @@ def put( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -303,7 +366,20 @@ def decorator(func) -> Callable: method=HTTPMethod.PUT ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) + options.update({"methods": [HTTPMethod.PUT]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -326,6 +402,7 @@ def delete( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -345,6 +422,7 @@ def delete( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -366,7 +444,20 @@ def decorator(func) -> Callable: method=HTTPMethod.DELETE ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) + options.update({"methods": [HTTPMethod.DELETE]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -389,6 +480,7 @@ def patch( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -408,6 +500,7 @@ def patch( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -429,10 +522,28 @@ def decorator(func) -> Callable: method=HTTPMethod.PATCH ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) + options.update({"methods": [HTTPMethod.PATCH]}) self._add_url_rule(rule, view_func=view_func, **options) return func return decorator + + def get_validate_response(self): + if hasattr(self, "validate_response"): + if self.validate_response is not None: + return self.validate_response diff --git a/flask_openapi3/utils.py b/flask_openapi3/utils.py index 205a8cfc..be9de8ba 100644 --- a/flask_openapi3/utils.py +++ b/flask_openapi3/utils.py @@ -592,6 +592,41 @@ def make_validation_error_response(e: ValidationError) -> FlaskResponse: return response +def run_validate_response(response: Any, responses: Optional[ResponseDict] = None) -> Any: + """Validate response""" + if responses is None: + return response + + if isinstance(response, tuple): # noqa + _resp, status_code = response[:2] + elif isinstance(response, FlaskResponse): + if response.mimetype != "application/json": + # only application/json + return response + _resp, status_code = response.json, response.status_code # noqa + else: + _resp, status_code = response, 200 + + # status_code is http.HTTPStatus + if isinstance(status_code, HTTPStatus): + status_code = status_code.value + + resp_model = responses.get(status_code) + + if resp_model is None: + return response + + assert inspect.isclass(resp_model) and \ + issubclass(resp_model, BaseModel), f"{resp_model} is invalid `pydantic.BaseModel`" + + if isinstance(_resp, str): + resp_model.model_validate_json(_resp) + else: + resp_model.model_validate(_resp) + + return response + + def parse_rule(rule: str, url_prefix=None) -> str: trail_slash = rule.endswith("/") diff --git a/flask_openapi3/view.py b/flask_openapi3/view.py index 23711bf9..418b1f11 100644 --- a/flask_openapi3/view.py +++ b/flask_openapi3/view.py @@ -31,6 +31,7 @@ def __init__( view_responses: Optional[ResponseDict] = None, doc_ui: bool = True, operation_id_callback: Callable = get_operation_id_for_path, + validate_response: Optional[bool] = None ): """ Create a class-based view @@ -44,6 +45,7 @@ def __init__( operation_id_callback: Callback function for custom operation_id generation. Receives name (str), path (str) and method (str) parameters. Defaults to `get_operation_id_for_path` from utils + validate_response: Verify the response body. """ self.url_prefix = url_prefix self.view_tags = view_tags or [] @@ -61,6 +63,8 @@ def __init__( self.tags: list[Tag] = [] self.tag_names: list[str] = [] + self.validate_response = validate_response + def route(self, rule: str): """Decorator for view class""" @@ -112,7 +116,8 @@ def doc( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, - doc_ui: bool = True + doc_ui: bool = True, + validate_response: Optional[bool] = None, ) -> Callable: """ Decorator for view method. @@ -130,6 +135,7 @@ def doc( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ new_responses = convert_responses_key_to_string(responses or {}) @@ -137,6 +143,9 @@ def doc( tags = tags + self.view_tags if tags else self.view_tags def decorator(func): + func.validate_response = validate_response + func.responses = responses + if self.doc_ui is False or doc_ui is False: return func @@ -194,7 +203,7 @@ def register( self, app: "OpenAPI", url_prefix: Optional[str] = None, - view_kwargs: Optional[dict[Any, Any]] = None + view_kwargs: Optional[dict[Any, Any]] = None, ) -> None: """ Register the API views with the given OpenAPI app. @@ -204,9 +213,14 @@ def register( url_prefix: A path to prepend to all the APIView's urls view_kwargs: Additional keyword arguments to pass to the API views. """ + for rule, (cls, methods) in self.views.items(): for method in methods: func = getattr(cls, method.lower()) + if func.validate_response is not None: + _validate_response = func.validate_response + else: + _validate_response = self.validate_response header, cookie, path, query, form, body, raw = parse_parameters(func, doc_ui=False) view_func = app.create_view_func( func, @@ -218,7 +232,9 @@ def register( body, raw, view_class=cls, - view_kwargs=view_kwargs + view_kwargs=view_kwargs, + responses=func.responses, + validate_response=_validate_response, ) if url_prefix and self.url_prefix and url_prefix != self.url_prefix: diff --git a/tests/test_validate_responses.py b/tests/test_validate_responses.py new file mode 100644 index 00000000..f4a6b301 --- /dev/null +++ b/tests/test_validate_responses.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import pytest +from pydantic import BaseModel, ValidationError + +from flask_openapi3 import OpenAPI, APIView +from flask_openapi3.blueprint import APIBlueprint + + +class BaseRequest(BaseModel): + """Base description""" + test_int: int + test_str: str + + +class GoodResponse(BaseRequest): ... + + +class BadResponse(BaseModel): + test_int: str + test_str: str + + +def test_no_validate_response(request): + """ + Response validation defaults to no validation + Response doesn't match schema and doesn't raise any errors + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + + @test_app.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + with test_app.test_client() as client: + resp = client.post("/test", json={"test_int": 1, "test_str": "s"}) + assert resp.status_code == 201 + + +def test_app_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name, validate_response=True) + test_app.config["TESTING"] = True + + @test_app.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_app_api_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + + @test_app.post("/test", responses={201: BadResponse}, validate_response=True) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_abp_level_no_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_abp = APIBlueprint("abp", __name__) + + @test_abp.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api(test_abp) + + with test_app.test_client() as client: + resp = client.post("/test", json={"test_int": 1, "test_str": "s"}) + assert resp.status_code == 201 + + +def test_abp_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_abp = APIBlueprint("abp", __name__, validate_response=True) + + @test_abp.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api(test_abp) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_abp_api_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_abp = APIBlueprint("abp", __name__) + + @test_abp.post("/test", responses={201: BadResponse}, validate_response=True) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api(test_abp) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_apiview_no_validate_response(request): + """ + Response validation defaults to no validation + Response doesn't match schema and doesn't raise any errors + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_api_view = APIView("") + + @test_api_view.route("/test") + class TestAPI: + @test_api_view.doc(responses={201: BadResponse}) + def post(self, body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api_view(test_api_view) + + with test_app.test_client() as client: + resp = client.post("/test", json={"test_int": 1, "test_str": "s"}) + assert resp.status_code == 201 + + +def test_apiview_app_level_validate_response(request): + """ + Validation turned on at app level + """ + + test_app = OpenAPI(request.node.name, validate_response=True) + test_app.config["TESTING"] = True + test_api_view = APIView("") + + @test_api_view.route("/test") + class TestAPI: + @test_api_view.doc(responses={201: BadResponse}) + def post(self, body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api_view(test_api_view) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_apiview_api_level_validate_response(request): + """ + Validation turned on at app level + """ + + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_api_view = APIView("") + + @test_api_view.route("/test") + class TestAPI: + @test_api_view.doc(responses={201: BadResponse}, validate_response=True) + def post(self, body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api_view(test_api_view) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"})