From 38eb6de0b4fe05ebea44b785b58f8e8c618858e5 Mon Sep 17 00:00:00 2001 From: Mahdi Mairza Date: Sat, 10 May 2025 18:37:46 +0100 Subject: [PATCH 1/3] ISSUE-220 allow for response validation --- flask_openapi3/blueprint.py | 3 + flask_openapi3/openapi.py | 3 + flask_openapi3/scaffold.py | 145 ++++++++++++++++- flask_openapi3/view.py | 31 +++- tests/test_validate_responses.py | 260 +++++++++++++++++++++++++++++++ 5 files changed, 433 insertions(+), 9 deletions(-) create mode 100644 tests/test_validate_responses.py diff --git a/flask_openapi3/blueprint.py b/flask_openapi3/blueprint.py index b0b85c39..1cf605fe 100644 --- a/flask_openapi3/blueprint.py +++ b/flask_openapi3/blueprint.py @@ -33,6 +33,7 @@ def __init__( abp_responses: Optional[ResponseDict] = None, doc_ui: bool = True, operation_id_callback: Callable = get_operation_id_for_path, + validate_response: Optional[bool] = None, **kwargs: Any ) -> None: """ @@ -71,6 +72,8 @@ def __init__( # Set the operation ID callback function self.operation_id_callback: Callable = operation_id_callback + self.abp_validate_response: Optional[bool] = validate_response + def register_api(self, api: "APIBlueprint") -> None: """Register a nested APIBlueprint""" diff --git a/flask_openapi3/openapi.py b/flask_openapi3/openapi.py index 5c1c91e8..7ffd0c45 100644 --- a/flask_openapi3/openapi.py +++ b/flask_openapi3/openapi.py @@ -63,6 +63,7 @@ def __init__( doc_ui: bool = True, doc_prefix: str = "/openapi", doc_url: str = "/openapi.json", + validate_response: Optional[bool] = None, **kwargs: Any ) -> None: """ @@ -144,6 +145,8 @@ def __init__( # Add the OpenAPI command self.cli.add_command(openapi_command) # type: ignore + self.app_validate_response: Optional[bool] = validate_response + # Initialize specification JSON self.spec_json: dict = {} self.spec = APISpec( diff --git a/flask_openapi3/scaffold.py b/flask_openapi3/scaffold.py index 0e38e5ac..4f6a54d5 100644 --- a/flask_openapi3/scaffold.py +++ b/flask_openapi3/scaffold.py @@ -6,6 +6,10 @@ from typing import Callable, Optional, Any from flask.wrappers import Response as FlaskResponse +from flask import current_app, Response as _Response +from http import HTTPStatus +from pydantic import BaseModel +from typing import Type, Dict from .models import ExternalDocumentation from .models import Server @@ -16,6 +20,100 @@ from .utils import HTTPMethod +def is_response_validation_enabled(validate_response: Optional[bool] = None, api_validate_response: Optional[bool] = None): + """ + Check if response validation is applicable. + + If different levels are set, priority follows this order + 1. Set at the api/route/single endpoint (api_response_validate) + 2. Set at APIBlueprint creation (response_validate) + 3. Set at OpenAPI creation (response_validate) + 4. Set in config `FLASK_OPENAPI_VALIDATE_RESPONSE` + + NOTE: #2 and #3 come from response_validate. + NOTE: What about APIView; should we not be able to set there as well? + NOTE: Should there be any inheritance this validation, IE: nested ABPs? + + Args: + response_validate: Whether the App/API-Blueprint wants to validate responses (context dependant) + route_response_validate: Whether the route wants to validate responses + """ + + global_response_validate: bool = current_app.config.get("FLASK_OPENAPI_VALIDATE_RESPONSE", False) + + if api_validate_response: + return True + + if validate_response: + return True + + return global_response_validate + + +# def run_validate_response(resp: Any, responses: Optional[Dict[str, Type[BaseModel]]] = None) -> None: +def run_validate_response(resp: Any, responses: Optional[ResponseDict] = None) -> None: + """Validate response""" + + # TODO: strict-mode? if a response is json and doesn't have a response status as well as not having + # a model to validate this should be flagged as an issue? which would necessitate the response + # to always return the correct/supported statuses as defined in "resposnes" + + warn = not current_app.config.get("FLASK_OPENAPI_DISABLE_WARNINGS", False) + + if warn: + print("Warning: " + "You are using `FLASK_OPENAPI_VALIDATE_RESPONSE=True`, " + "please do not use it in the production environment, " + "because it will reduce the performance. " + "NOTE, you can disable this warning with `Flask.config['FLASK_OPENAPI_DISABLE_WARNINGS'] = True`") + + + if not responses: + print("Warning, response validation on but endpoint has no responses set") + return + + if isinstance(resp, tuple): # noqa + _resp, status_code = resp[:2] + + elif isinstance(resp, _Response): + if resp.mimetype != "application/json": + # only application/json + return + # raise TypeError("`Response` mimetype must be application/json.") + _resp, status_code = resp.json, resp.status_code # noqa + + else: + _resp, status_code = resp, 200 + + # status_code is http.HTTPStatus + if isinstance(status_code, HTTPStatus): + status_code = status_code.value + + resp_model = responses.get(status_code) + + if resp_model is None: + if warn: + print("Warning: missing status code map to `pydantic.BaseModel`") + + return + + assert inspect.isclass(resp_model) and \ + issubclass(resp_model, BaseModel), f"{resp_model} is invalid `pydantic.BaseModel`" + + if isinstance(_resp, str): + resp_model.model_validate_json(_resp) + + elif not isinstance(_resp, dict): + resp_model.model_validate(_resp) + + else: + try: + resp_model(**_resp) + + except TypeError: + raise TypeError(f"`{resp_model.__name__}` validation failed, must be a mapping.") + + class APIScaffold: def _collect_openapi_info( self, @@ -52,6 +150,7 @@ def _add_url_rule( @staticmethod def create_view_func( + # self, func, header, cookie, @@ -60,9 +159,13 @@ def create_view_func( form, body, raw, + responses: Optional[ResponseDict] = None, view_class=None, - view_kwargs=None + view_kwargs=None, + parent_validate_response: Optional[bool] = None, + api_validate_response: Optional[bool] = None, ): + is_coroutine_function = inspect.iscoroutinefunction(func) if is_coroutine_function: @wraps(func) @@ -89,6 +192,10 @@ async def view_func(**kwargs) -> FlaskResponse: response = await func(view_object, **func_kwargs) else: response = await func(**func_kwargs) + + if is_response_validation_enabled(validate_response=parent_validate_response, api_validate_response=api_validate_response) and responses: + run_validate_response(response, responses) + return response else: @wraps(func) @@ -115,6 +222,10 @@ def view_func(**kwargs) -> FlaskResponse: response = func(view_object, **func_kwargs) else: response = func(**func_kwargs) + + if is_response_validation_enabled(validate_response=parent_validate_response, api_validate_response=api_validate_response): + run_validate_response(response, responses) + return response if not hasattr(func, "view"): @@ -137,6 +248,7 @@ def get( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -177,7 +289,9 @@ def decorator(func) -> Callable: method=HTTPMethod.GET ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + parent_validate_response = self.get_parent_validation() + view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + options.update({"methods": [HTTPMethod.GET]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -200,6 +314,7 @@ def post( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: bool = False, **options: Any ) -> Callable: """ @@ -240,7 +355,9 @@ def decorator(func) -> Callable: method=HTTPMethod.POST ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + parent_validate_response = self.get_parent_validation() + view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + options.update({"methods": [HTTPMethod.POST]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -263,6 +380,7 @@ def put( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: bool = False, **options: Any ) -> Callable: """ @@ -303,7 +421,9 @@ def decorator(func) -> Callable: method=HTTPMethod.PUT ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + parent_validate_response = self.get_parent_validation() + view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + options.update({"methods": [HTTPMethod.PUT]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -326,6 +446,7 @@ def delete( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: bool = False, **options: Any ) -> Callable: """ @@ -366,7 +487,9 @@ def decorator(func) -> Callable: method=HTTPMethod.DELETE ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + parent_validate_response = self.get_parent_validation() + view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + options.update({"methods": [HTTPMethod.DELETE]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -389,6 +512,7 @@ def patch( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, + validate_response: bool = False, **options: Any ) -> Callable: """ @@ -429,10 +553,19 @@ def decorator(func) -> Callable: method=HTTPMethod.PATCH ) - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw) + parent_validate_response = self.get_parent_validation() + view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + options.update({"methods": [HTTPMethod.PATCH]}) self._add_url_rule(rule, view_func=view_func, **options) return func return decorator + + def get_parent_validation(self): + # NOTE: abp_ vs app_ distinction without a difference in this context? + if hasattr(self, "abp_validate_response"): + return self.abp_validate_response + + return self.app_validate_response diff --git a/flask_openapi3/view.py b/flask_openapi3/view.py index 23711bf9..b5532f04 100644 --- a/flask_openapi3/view.py +++ b/flask_openapi3/view.py @@ -112,7 +112,8 @@ def doc( security: Optional[list[dict[str, list[Any]]]] = None, servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, - doc_ui: bool = True + doc_ui: bool = True, + validate_response: Optional[bool] = None, ) -> Callable: """ Decorator for view method. @@ -132,11 +133,15 @@ def doc( doc_ui: Declares this operation to be shown. Default to True. """ + # import ipdb; ipdb.set_trace() new_responses = convert_responses_key_to_string(responses or {}) security = security or [] tags = tags + self.view_tags if tags else self.view_tags def decorator(func): + func.validate_response = validate_response + func.responses = responses + if self.doc_ui is False or doc_ui is False: return func @@ -186,6 +191,10 @@ def decorator(func): get_responses(combine_responses, self.components_schemas, operation) func.operation = operation + # NOTE: combite_responses instead of responses here? or is above enough? + # func.responses = responses + # func.responses = combite_responses + return func return decorator @@ -194,7 +203,7 @@ def register( self, app: "OpenAPI", url_prefix: Optional[str] = None, - view_kwargs: Optional[dict[Any, Any]] = None + view_kwargs: Optional[dict[Any, Any]] = None, ) -> None: """ Register the API views with the given OpenAPI app. @@ -204,9 +213,21 @@ def register( url_prefix: A path to prepend to all the APIView's urls view_kwargs: Additional keyword arguments to pass to the API views. """ + for rule, (cls, methods) in self.views.items(): for method in methods: func = getattr(cls, method.lower()) + + if isinstance(func.responses, dict): + responses = func.responses.copy() + else: + responses = func.responses + + validate_response = func.validate_response + + del func.responses + del func.validate_response + header, cookie, path, query, form, body, raw = parse_parameters(func, doc_ui=False) view_func = app.create_view_func( func, @@ -217,8 +238,12 @@ def register( form, body, raw, + responses=responses, view_class=cls, - view_kwargs=view_kwargs + view_kwargs=view_kwargs, + parent_validate_response=app.app_validate_response, + # NOTE: do we support this at APIView definition time or do we want it at each class/route? or not at all? + api_validate_response=validate_response, ) if url_prefix and self.url_prefix and url_prefix != self.url_prefix: diff --git a/tests/test_validate_responses.py b/tests/test_validate_responses.py new file mode 100644 index 00000000..1cc080b6 --- /dev/null +++ b/tests/test_validate_responses.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import pytest + + +from pydantic import BaseModel, ValidationError + +from flask_openapi3 import OpenAPI, APIView +from flask_openapi3.blueprint import APIBlueprint + + +class BaseRequest(BaseModel): + """Base description""" + test_int: int + test_str: str + + +class GoodResponse(BaseRequest): ... + + +class BadResponse(BaseModel): + test_int: str + test_str: str + + +def test_no_validate_response(request): + """ + Response validation defaults to no validation + Response doesn't match schema and doesn't raise any errors + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + + @test_app.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + with test_app.test_client() as client: + resp = client.post("/test", json={"test_int": 1, "test_str": "s"}) + assert resp.status_code == 201 + + +def test_app_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name, validate_response=True) + test_app.config["TESTING"] = True + + @test_app.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_app_api_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + + @test_app.post("/test", responses={201: BadResponse}, validate_response=True) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_app_config_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_app.config["FLASK_OPENAPI_VALIDATE_RESPONSE"] = True + + @test_app.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + + +def test_abp_level_no_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_abp = APIBlueprint("abp", __name__) + + @test_abp.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api(test_abp) + + with test_app.test_client() as client: + resp = client.post("/test", json={"test_int": 1, "test_str": "s"}) + assert resp.status_code == 201 + + +def test_abp_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_abp = APIBlueprint("abp", __name__, validate_response=True) + + @test_abp.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api(test_abp) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_abp_api_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_abp = APIBlueprint("abp", __name__) + + @test_abp.post("/test", responses={201: BadResponse}, validate_response=True) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api(test_abp) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_abp_config_level_validate_response(request): + """ + Validation turned on at app level + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_app.config["FLASK_OPENAPI_VALIDATE_RESPONSE"] = True + test_abp = APIBlueprint("abp", __name__) + + @test_abp.post("/test", responses={201: BadResponse}) + def endpoint_test(body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api(test_abp) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_apiview_no_validate_response(request): + """ + Response validation defaults to no validation + Response doesn't match schema and doesn't raise any errors + """ + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_api_view = APIView("") + + @test_api_view.route("/test") + class TestAPI: + @test_api_view.doc(responses={201: BadResponse}) + def post(self, body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api_view(test_api_view) + + with test_app.test_client() as client: + resp = client.post("/test", json={"test_int": 1, "test_str": "s"}) + assert resp.status_code == 201 + + +def test_apiview_app_level_validate_response(request): + """ + Validation turned on at app level + """ + + test_app = OpenAPI(request.node.name, validate_response=True) + test_app.config["TESTING"] = True + test_api_view = APIView("") + + @test_api_view.route("/test") + class TestAPI: + @test_api_view.doc(responses={201: BadResponse}) + def post(self, body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api_view(test_api_view) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +def test_apiview_api_level_validate_response(request): + """ + Validation turned on at app level + """ + + test_app = OpenAPI(request.node.name) + test_app.config["TESTING"] = True + test_api_view = APIView("") + + @test_api_view.route("/test") + class TestAPI: + @test_api_view.doc(responses={201: BadResponse}, validate_response=True) + def post(self, body: BaseRequest): + return body.model_dump(), 201 + + test_app.register_api_view(test_api_view) + + with test_app.test_client() as client: + with pytest.raises(ValidationError): + _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) + + +# NOTE: can't register APIView on APIBlueprint - missing feature? +# def test_apiview_abp_level_validate_response(request): +# """ +# Validation turned on at app level +# """ +# +# test_app = OpenAPI(request.node.name, validate_response=True) +# test_abp = APIBlueprint("abp", __name__) +# test_app.config["TESTING"] = True +# test_api_view = APIView() +# +# @test_api_view.route("/test") +# class TestAPI: +# +# @test_api_view.doc(responses={201: BadResponse}) +# def post(self, body: BaseRequest): +# return body.model_dump(), 201 +# +# test_abp.register_api_view(test_api_view) +# test_app.register_api(test_abp) +# +# +# with test_app.test_client() as client: +# with pytest.raises(ValidationError): +# _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) From 7248d4682856588e06844d7ff020b2ed1bb076c0 Mon Sep 17 00:00:00 2001 From: luolingchun Date: Mon, 19 May 2025 16:41:03 +0800 Subject: [PATCH 2/3] fix --- docs/Usage/Response.md | 26 +++- flask_openapi3/blueprint.py | 4 +- flask_openapi3/openapi.py | 4 +- flask_openapi3/scaffold.py | 229 ++++++++++++++----------------- flask_openapi3/utils.py | 33 +++++ flask_openapi3/view.py | 29 ++-- tests/test_validate_responses.py | 67 --------- 7 files changed, 177 insertions(+), 215 deletions(-) diff --git a/docs/Usage/Response.md b/docs/Usage/Response.md index 8d8300b9..3efa0ba5 100644 --- a/docs/Usage/Response.md +++ b/docs/Usage/Response.md @@ -53,9 +53,33 @@ def hello(path: HelloPath): return response ``` - ![image-20210526104627124](../assets/image-20210526104627124.png) +## Validate responses + +By default, responses are not validated. If you need to validate responses, set validate_responses to True. Here are +several ways to achieve this: + +```python +# 1. APP level +app = OpenAPI(__name__, validate_response=True) + +# 2. APIBlueprint level +api = APIBlueprint(__name__, validate_response=True) + +# 3. APIView level +@api_view.route("/test") +class TestAPI: + @api_view.doc(responses={201: Response}, validate_response=True) + def post(self): + ... + +# 4. api level +@app.post("/test", responses={201: Response}, validate_response=True) +def endpoint_test(body: BaseRequest): + ... +``` + ## More information about OpenAPI responses - [OpenAPI Responses Object](https://spec.openapis.org/oas/v3.1.0#responses-object), it includes the Response Object. diff --git a/flask_openapi3/blueprint.py b/flask_openapi3/blueprint.py index 1cf605fe..c49735b7 100644 --- a/flask_openapi3/blueprint.py +++ b/flask_openapi3/blueprint.py @@ -50,6 +50,7 @@ def __init__( operation_id_callback: Callback function for custom operation_id generation. Receives name (str), path (str) and method (str) parameters. Defaults to `get_operation_id_for_path` from utils + validate_response: Verify the response body. **kwargs: Flask Blueprint kwargs """ super(APIBlueprint, self).__init__(name, import_name, **kwargs) @@ -72,7 +73,8 @@ def __init__( # Set the operation ID callback function self.operation_id_callback: Callable = operation_id_callback - self.abp_validate_response: Optional[bool] = validate_response + # Verify the response body + self.validate_response = validate_response def register_api(self, api: "APIBlueprint") -> None: """Register a nested APIBlueprint""" diff --git a/flask_openapi3/openapi.py b/flask_openapi3/openapi.py index 7ffd0c45..e48f21a7 100644 --- a/flask_openapi3/openapi.py +++ b/flask_openapi3/openapi.py @@ -96,6 +96,7 @@ def __init__( Defaults to "/openapi". doc_url: URL for accessing the OpenAPI specification document in JSON format. Defaults to "/openapi.json". + validate_response: Verify the response body. **kwargs: Additional kwargs to be passed to Flask. """ super(OpenAPI, self).__init__(import_name, **kwargs) @@ -145,7 +146,8 @@ def __init__( # Add the OpenAPI command self.cli.add_command(openapi_command) # type: ignore - self.app_validate_response: Optional[bool] = validate_response + # Verify the response body + self.validate_response = validate_response # Initialize specification JSON self.spec_json: dict = {} diff --git a/flask_openapi3/scaffold.py b/flask_openapi3/scaffold.py index 4f6a54d5..da43469d 100644 --- a/flask_openapi3/scaffold.py +++ b/flask_openapi3/scaffold.py @@ -5,11 +5,8 @@ from functools import wraps from typing import Callable, Optional, Any +from flask import current_app from flask.wrappers import Response as FlaskResponse -from flask import current_app, Response as _Response -from http import HTTPStatus -from pydantic import BaseModel -from typing import Type, Dict from .models import ExternalDocumentation from .models import Server @@ -18,100 +15,7 @@ from .types import ParametersTuple from .types import ResponseDict from .utils import HTTPMethod - - -def is_response_validation_enabled(validate_response: Optional[bool] = None, api_validate_response: Optional[bool] = None): - """ - Check if response validation is applicable. - - If different levels are set, priority follows this order - 1. Set at the api/route/single endpoint (api_response_validate) - 2. Set at APIBlueprint creation (response_validate) - 3. Set at OpenAPI creation (response_validate) - 4. Set in config `FLASK_OPENAPI_VALIDATE_RESPONSE` - - NOTE: #2 and #3 come from response_validate. - NOTE: What about APIView; should we not be able to set there as well? - NOTE: Should there be any inheritance this validation, IE: nested ABPs? - - Args: - response_validate: Whether the App/API-Blueprint wants to validate responses (context dependant) - route_response_validate: Whether the route wants to validate responses - """ - - global_response_validate: bool = current_app.config.get("FLASK_OPENAPI_VALIDATE_RESPONSE", False) - - if api_validate_response: - return True - - if validate_response: - return True - - return global_response_validate - - -# def run_validate_response(resp: Any, responses: Optional[Dict[str, Type[BaseModel]]] = None) -> None: -def run_validate_response(resp: Any, responses: Optional[ResponseDict] = None) -> None: - """Validate response""" - - # TODO: strict-mode? if a response is json and doesn't have a response status as well as not having - # a model to validate this should be flagged as an issue? which would necessitate the response - # to always return the correct/supported statuses as defined in "resposnes" - - warn = not current_app.config.get("FLASK_OPENAPI_DISABLE_WARNINGS", False) - - if warn: - print("Warning: " - "You are using `FLASK_OPENAPI_VALIDATE_RESPONSE=True`, " - "please do not use it in the production environment, " - "because it will reduce the performance. " - "NOTE, you can disable this warning with `Flask.config['FLASK_OPENAPI_DISABLE_WARNINGS'] = True`") - - - if not responses: - print("Warning, response validation on but endpoint has no responses set") - return - - if isinstance(resp, tuple): # noqa - _resp, status_code = resp[:2] - - elif isinstance(resp, _Response): - if resp.mimetype != "application/json": - # only application/json - return - # raise TypeError("`Response` mimetype must be application/json.") - _resp, status_code = resp.json, resp.status_code # noqa - - else: - _resp, status_code = resp, 200 - - # status_code is http.HTTPStatus - if isinstance(status_code, HTTPStatus): - status_code = status_code.value - - resp_model = responses.get(status_code) - - if resp_model is None: - if warn: - print("Warning: missing status code map to `pydantic.BaseModel`") - - return - - assert inspect.isclass(resp_model) and \ - issubclass(resp_model, BaseModel), f"{resp_model} is invalid `pydantic.BaseModel`" - - if isinstance(_resp, str): - resp_model.model_validate_json(_resp) - - elif not isinstance(_resp, dict): - resp_model.model_validate(_resp) - - else: - try: - resp_model(**_resp) - - except TypeError: - raise TypeError(f"`{resp_model.__name__}` validation failed, must be a mapping.") +from .utils import run_validate_response class APIScaffold: @@ -133,10 +37,10 @@ def _collect_openapi_info( doc_ui: bool = True, method: str = HTTPMethod.GET ) -> ParametersTuple: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def register_api(self, api) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def _add_url_rule( self, @@ -146,7 +50,7 @@ def _add_url_rule( provide_automatic_options=None, **options, ) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover @staticmethod def create_view_func( @@ -159,11 +63,10 @@ def create_view_func( form, body, raw, - responses: Optional[ResponseDict] = None, view_class=None, view_kwargs=None, - parent_validate_response: Optional[bool] = None, - api_validate_response: Optional[bool] = None, + responses: Optional[ResponseDict] = None, + validate_response: Optional[bool] = None, ): is_coroutine_function = inspect.iscoroutinefunction(func) @@ -193,7 +96,15 @@ async def view_func(**kwargs) -> FlaskResponse: else: response = await func(**func_kwargs) - if is_response_validation_enabled(validate_response=parent_validate_response, api_validate_response=api_validate_response) and responses: + if hasattr(current_app, "validate_response"): + if validate_response is None: + _validate_response = current_app.validate_response + else: + _validate_response = validate_response + else: + _validate_response = validate_response + + if _validate_response and responses: run_validate_response(response, responses) return response @@ -223,7 +134,15 @@ def view_func(**kwargs) -> FlaskResponse: else: response = func(**func_kwargs) - if is_response_validation_enabled(validate_response=parent_validate_response, api_validate_response=api_validate_response): + if hasattr(current_app, "validate_response"): + if validate_response is None: + _validate_response = current_app.validate_response + else: + _validate_response = validate_response + else: + _validate_response = validate_response + + if _validate_response and responses: run_validate_response(response, responses) return response @@ -268,6 +187,7 @@ def get( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -289,8 +209,19 @@ def decorator(func) -> Callable: method=HTTPMethod.GET ) - parent_validate_response = self.get_parent_validation() - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) options.update({"methods": [HTTPMethod.GET]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -314,7 +245,7 @@ def post( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, - validate_response: bool = False, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -334,6 +265,7 @@ def post( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -355,8 +287,19 @@ def decorator(func) -> Callable: method=HTTPMethod.POST ) - parent_validate_response = self.get_parent_validation() - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) options.update({"methods": [HTTPMethod.POST]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -380,7 +323,7 @@ def put( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, - validate_response: bool = False, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -400,6 +343,7 @@ def put( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -421,8 +365,19 @@ def decorator(func) -> Callable: method=HTTPMethod.PUT ) - parent_validate_response = self.get_parent_validation() - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) options.update({"methods": [HTTPMethod.PUT]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -446,7 +401,7 @@ def delete( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, - validate_response: bool = False, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -466,6 +421,7 @@ def delete( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -487,8 +443,19 @@ def decorator(func) -> Callable: method=HTTPMethod.DELETE ) - parent_validate_response = self.get_parent_validation() - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) options.update({"methods": [HTTPMethod.DELETE]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -512,7 +479,7 @@ def patch( servers: Optional[list[Server]] = None, openapi_extensions: Optional[dict[str, Any]] = None, doc_ui: bool = True, - validate_response: bool = False, + validate_response: Optional[bool] = None, **options: Any ) -> Callable: """ @@ -532,6 +499,7 @@ def patch( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ def decorator(func) -> Callable: @@ -553,8 +521,19 @@ def decorator(func) -> Callable: method=HTTPMethod.PATCH ) - parent_validate_response = self.get_parent_validation() - view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw, responses=responses, parent_validate_response=parent_validate_response, api_validate_response=validate_response) + _validate_response = validate_response if validate_response is not None else self.get_validate_response() + view_func = self.create_view_func( + func, + header, + cookie, + path, + query, + form, + body, + raw, + responses=responses, + validate_response=_validate_response + ) options.update({"methods": [HTTPMethod.PATCH]}) self._add_url_rule(rule, view_func=view_func, **options) @@ -563,9 +542,7 @@ def decorator(func) -> Callable: return decorator - def get_parent_validation(self): - # NOTE: abp_ vs app_ distinction without a difference in this context? - if hasattr(self, "abp_validate_response"): - return self.abp_validate_response - - return self.app_validate_response + def get_validate_response(self): + if hasattr(self, "validate_response"): + if self.validate_response is not None: + return self.validate_response diff --git a/flask_openapi3/utils.py b/flask_openapi3/utils.py index 205a8cfc..496d0673 100644 --- a/flask_openapi3/utils.py +++ b/flask_openapi3/utils.py @@ -592,6 +592,39 @@ def make_validation_error_response(e: ValidationError) -> FlaskResponse: return response +def run_validate_response(response: Any, responses: Optional[ResponseDict] = None) -> None: + """Validate response""" + if responses is None: + return + + if isinstance(response, tuple): # noqa + _resp, status_code = response[:2] + elif isinstance(response, FlaskResponse): + if response.mimetype != "application/json": + # only application/json + return + _resp, status_code = response.json, response.status_code # noqa + else: + _resp, status_code = response, 200 + + # status_code is http.HTTPStatus + if isinstance(status_code, HTTPStatus): + status_code = status_code.value + + resp_model = responses.get(status_code) + + if resp_model is None: + return + + assert inspect.isclass(resp_model) and \ + issubclass(resp_model, BaseModel), f"{resp_model} is invalid `pydantic.BaseModel`" + + if isinstance(_resp, str): + resp_model.model_validate_json(_resp) + else: + resp_model.model_validate(_resp) + + def parse_rule(rule: str, url_prefix=None) -> str: trail_slash = rule.endswith("/") diff --git a/flask_openapi3/view.py b/flask_openapi3/view.py index b5532f04..418b1f11 100644 --- a/flask_openapi3/view.py +++ b/flask_openapi3/view.py @@ -31,6 +31,7 @@ def __init__( view_responses: Optional[ResponseDict] = None, doc_ui: bool = True, operation_id_callback: Callable = get_operation_id_for_path, + validate_response: Optional[bool] = None ): """ Create a class-based view @@ -44,6 +45,7 @@ def __init__( operation_id_callback: Callback function for custom operation_id generation. Receives name (str), path (str) and method (str) parameters. Defaults to `get_operation_id_for_path` from utils + validate_response: Verify the response body. """ self.url_prefix = url_prefix self.view_tags = view_tags or [] @@ -61,6 +63,8 @@ def __init__( self.tags: list[Tag] = [] self.tag_names: list[str] = [] + self.validate_response = validate_response + def route(self, rule: str): """Decorator for view class""" @@ -131,9 +135,9 @@ def doc( servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. doc_ui: Declares this operation to be shown. Default to True. + validate_response: Verify the response body. """ - # import ipdb; ipdb.set_trace() new_responses = convert_responses_key_to_string(responses or {}) security = security or [] tags = tags + self.view_tags if tags else self.view_tags @@ -191,10 +195,6 @@ def decorator(func): get_responses(combine_responses, self.components_schemas, operation) func.operation = operation - # NOTE: combite_responses instead of responses here? or is above enough? - # func.responses = responses - # func.responses = combite_responses - return func return decorator @@ -217,17 +217,10 @@ def register( for rule, (cls, methods) in self.views.items(): for method in methods: func = getattr(cls, method.lower()) - - if isinstance(func.responses, dict): - responses = func.responses.copy() + if func.validate_response is not None: + _validate_response = func.validate_response else: - responses = func.responses - - validate_response = func.validate_response - - del func.responses - del func.validate_response - + _validate_response = self.validate_response header, cookie, path, query, form, body, raw = parse_parameters(func, doc_ui=False) view_func = app.create_view_func( func, @@ -238,12 +231,10 @@ def register( form, body, raw, - responses=responses, view_class=cls, view_kwargs=view_kwargs, - parent_validate_response=app.app_validate_response, - # NOTE: do we support this at APIView definition time or do we want it at each class/route? or not at all? - api_validate_response=validate_response, + responses=func.responses, + validate_response=_validate_response, ) if url_prefix and self.url_prefix and url_prefix != self.url_prefix: diff --git a/tests/test_validate_responses.py b/tests/test_validate_responses.py index 1cc080b6..f4a6b301 100644 --- a/tests/test_validate_responses.py +++ b/tests/test_validate_responses.py @@ -1,8 +1,6 @@ from __future__ import annotations import pytest - - from pydantic import BaseModel, ValidationError from flask_openapi3 import OpenAPI, APIView @@ -72,24 +70,6 @@ def endpoint_test(body: BaseRequest): _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) -def test_app_config_level_validate_response(request): - """ - Validation turned on at app level - """ - test_app = OpenAPI(request.node.name) - test_app.config["TESTING"] = True - test_app.config["FLASK_OPENAPI_VALIDATE_RESPONSE"] = True - - @test_app.post("/test", responses={201: BadResponse}) - def endpoint_test(body: BaseRequest): - return body.model_dump(), 201 - - with test_app.test_client() as client: - with pytest.raises(ValidationError): - _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) - - - def test_abp_level_no_validate_response(request): """ Validation turned on at app level @@ -147,26 +127,6 @@ def endpoint_test(body: BaseRequest): _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) -def test_abp_config_level_validate_response(request): - """ - Validation turned on at app level - """ - test_app = OpenAPI(request.node.name) - test_app.config["TESTING"] = True - test_app.config["FLASK_OPENAPI_VALIDATE_RESPONSE"] = True - test_abp = APIBlueprint("abp", __name__) - - @test_abp.post("/test", responses={201: BadResponse}) - def endpoint_test(body: BaseRequest): - return body.model_dump(), 201 - - test_app.register_api(test_abp) - - with test_app.test_client() as client: - with pytest.raises(ValidationError): - _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) - - def test_apiview_no_validate_response(request): """ Response validation defaults to no validation @@ -231,30 +191,3 @@ def post(self, body: BaseRequest): with test_app.test_client() as client: with pytest.raises(ValidationError): _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) - - -# NOTE: can't register APIView on APIBlueprint - missing feature? -# def test_apiview_abp_level_validate_response(request): -# """ -# Validation turned on at app level -# """ -# -# test_app = OpenAPI(request.node.name, validate_response=True) -# test_abp = APIBlueprint("abp", __name__) -# test_app.config["TESTING"] = True -# test_api_view = APIView() -# -# @test_api_view.route("/test") -# class TestAPI: -# -# @test_api_view.doc(responses={201: BadResponse}) -# def post(self, body: BaseRequest): -# return body.model_dump(), 201 -# -# test_abp.register_api_view(test_api_view) -# test_app.register_api(test_abp) -# -# -# with test_app.test_client() as client: -# with pytest.raises(ValidationError): -# _ = client.post("/test", json={"test_int": 1, "test_str": "s"}) From 9c5da0578106fccd3e86f5281d8786d57b86ae0b Mon Sep 17 00:00:00 2001 From: luolingchun Date: Thu, 22 May 2025 10:10:23 +0800 Subject: [PATCH 3/3] Allow to override `validate_response_callback` --- docs/Usage/Response.md | 13 +++++++++++++ flask_openapi3/models/components.py | 4 ++-- flask_openapi3/openapi.py | 4 ++++ flask_openapi3/scaffold.py | 9 +++++---- flask_openapi3/utils.py | 10 ++++++---- 5 files changed, 30 insertions(+), 10 deletions(-) diff --git a/docs/Usage/Response.md b/docs/Usage/Response.md index 3efa0ba5..023e9fa9 100644 --- a/docs/Usage/Response.md +++ b/docs/Usage/Response.md @@ -80,6 +80,19 @@ def endpoint_test(body: BaseRequest): ... ``` +You can also customize the default behavior of response validation by using a custom `validate_response_callback`. + +```python + +def validate_response_callback(response: Any, responses: Optional[ResponseDict] = None) -> Any: + + # do something + + return response + +app = OpenAPI(__name__, validate_response=True, validate_response_callback=validate_response_callback) +``` + ## More information about OpenAPI responses - [OpenAPI Responses Object](https://spec.openapis.org/oas/v3.1.0#responses-object), it includes the Response Object. diff --git a/flask_openapi3/models/components.py b/flask_openapi3/models/components.py index 1bdbeb15..c498b52e 100644 --- a/flask_openapi3/models/components.py +++ b/flask_openapi3/models/components.py @@ -3,7 +3,7 @@ # @Time : 2023/7/4 9:36 from typing import Optional, Union, Any -from pydantic import BaseModel, Field +from pydantic import BaseModel from .callback import Callback from .example import Example @@ -23,7 +23,7 @@ class Components(BaseModel): https://spec.openapis.org/oas/v3.1.0#components-object """ - schemas: Optional[dict[str, Union[Reference, Schema]]] = Field(None) + schemas: Optional[dict[str, Union[Reference, Schema]]] = None responses: Optional[dict[str, Union[Response, Reference]]] = None parameters: Optional[dict[str, Union[Parameter, Reference]]] = None examples: Optional[dict[str, Union[Example, Reference]]] = None diff --git a/flask_openapi3/openapi.py b/flask_openapi3/openapi.py index e48f21a7..91b1a076 100644 --- a/flask_openapi3/openapi.py +++ b/flask_openapi3/openapi.py @@ -42,6 +42,7 @@ from .utils import parse_and_store_tags from .utils import parse_method from .utils import parse_parameters +from .utils import run_validate_response from .view import APIView @@ -64,6 +65,7 @@ def __init__( doc_prefix: str = "/openapi", doc_url: str = "/openapi.json", validate_response: Optional[bool] = None, + validate_response_callback: Callable = run_validate_response, **kwargs: Any ) -> None: """ @@ -97,6 +99,7 @@ def __init__( doc_url: URL for accessing the OpenAPI specification document in JSON format. Defaults to "/openapi.json". validate_response: Verify the response body. + validate_response_callback: Validation and return response. **kwargs: Additional kwargs to be passed to Flask. """ super(OpenAPI, self).__init__(import_name, **kwargs) @@ -148,6 +151,7 @@ def __init__( # Verify the response body self.validate_response = validate_response + self.validate_response_callback = validate_response_callback # Initialize specification JSON self.spec_json: dict = {} diff --git a/flask_openapi3/scaffold.py b/flask_openapi3/scaffold.py index da43469d..728c8192 100644 --- a/flask_openapi3/scaffold.py +++ b/flask_openapi3/scaffold.py @@ -15,7 +15,6 @@ from .types import ParametersTuple from .types import ResponseDict from .utils import HTTPMethod -from .utils import run_validate_response class APIScaffold: @@ -105,7 +104,8 @@ async def view_func(**kwargs) -> FlaskResponse: _validate_response = validate_response if _validate_response and responses: - run_validate_response(response, responses) + validate_response_callback = getattr(current_app, "validate_response_callback") + return validate_response_callback(response, responses) return response else: @@ -141,9 +141,10 @@ def view_func(**kwargs) -> FlaskResponse: _validate_response = validate_response else: _validate_response = validate_response - + if _validate_response and responses: - run_validate_response(response, responses) + validate_response_callback = getattr(current_app, "validate_response_callback") + return validate_response_callback(response, responses) return response diff --git a/flask_openapi3/utils.py b/flask_openapi3/utils.py index 496d0673..be9de8ba 100644 --- a/flask_openapi3/utils.py +++ b/flask_openapi3/utils.py @@ -592,17 +592,17 @@ def make_validation_error_response(e: ValidationError) -> FlaskResponse: return response -def run_validate_response(response: Any, responses: Optional[ResponseDict] = None) -> None: +def run_validate_response(response: Any, responses: Optional[ResponseDict] = None) -> Any: """Validate response""" if responses is None: - return + return response if isinstance(response, tuple): # noqa _resp, status_code = response[:2] elif isinstance(response, FlaskResponse): if response.mimetype != "application/json": # only application/json - return + return response _resp, status_code = response.json, response.status_code # noqa else: _resp, status_code = response, 200 @@ -614,7 +614,7 @@ def run_validate_response(response: Any, responses: Optional[ResponseDict] = Non resp_model = responses.get(status_code) if resp_model is None: - return + return response assert inspect.isclass(resp_model) and \ issubclass(resp_model, BaseModel), f"{resp_model} is invalid `pydantic.BaseModel`" @@ -624,6 +624,8 @@ def run_validate_response(response: Any, responses: Optional[ResponseDict] = Non else: resp_model.model_validate(_resp) + return response + def parse_rule(rule: str, url_prefix=None) -> str: trail_slash = rule.endswith("/")