diff --git a/flask_openapi3/request.py b/flask_openapi3/request.py index f6c1d39e..a83c6366 100644 --- a/flask_openapi3/request.py +++ b/flask_openapi3/request.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- # @Author : llc # @Time : 2022/4/1 16:54 +from __future__ import annotations + import json from json import JSONDecodeError from typing import Any, Type, Optional -from flask import request, current_app, abort +from flask import request from pydantic import ValidationError, BaseModel from pydantic.fields import FieldInfo from werkzeug.datastructures.structures import MultiDict @@ -155,7 +157,7 @@ def _validate_request( body: Optional[Type[BaseModel]] = None, raw: Optional[Type[BaseModel]] = None, path_kwargs: Optional[dict[Any, Any]] = None -) -> dict: +) -> tuple[dict, Any | None]: """ Validate requests and responses. @@ -170,13 +172,12 @@ def _validate_request( Returns: dict: Request kwargs. - - Raises: - ValidationError: If validation fails. + error: ValidationError """ # Dictionary to store func kwargs func_kwargs: dict = {} + error = None try: # Validate header, cookie, path, and query parameters @@ -195,8 +196,6 @@ def _validate_request( if raw: func_kwargs["raw"] = request except ValidationError as e: - # Create a response with validation error details - validation_error_callback = getattr(current_app, "validation_error_callback") - abort(validation_error_callback(e)) + error = e - return func_kwargs + return func_kwargs, error diff --git a/flask_openapi3/scaffold.py b/flask_openapi3/scaffold.py index 0e38e5ac..59e95bd4 100644 --- a/flask_openapi3/scaffold.py +++ b/flask_openapi3/scaffold.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- # @Author : llc # @Time : 2022/8/30 9:40 +from __future__ import annotations + import inspect from functools import wraps from typing import Callable, Optional, Any -from flask.wrappers import Response as FlaskResponse +from flask import abort, current_app from .models import ExternalDocumentation from .models import Server @@ -35,10 +37,10 @@ def _collect_openapi_info( doc_ui: bool = True, method: str = HTTPMethod.GET ) -> ParametersTuple: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def register_api(self, api) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def _add_url_rule( self, @@ -48,7 +50,7 @@ def _add_url_rule( provide_automatic_options=None, **options, ) -> None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover @staticmethod def create_view_func( @@ -66,8 +68,8 @@ def create_view_func( is_coroutine_function = inspect.iscoroutinefunction(func) if is_coroutine_function: @wraps(func) - async def view_func(**kwargs) -> FlaskResponse: - func_kwargs = _validate_request( + async def view_func(**kwargs) -> Any | None: + func_kwargs, error = _validate_request( header=header, cookie=cookie, path=path, @@ -77,23 +79,30 @@ async def view_func(**kwargs) -> FlaskResponse: raw=raw, path_kwargs=kwargs ) - - # handle async request - if view_class: - signature = inspect.signature(view_class.__init__) - parameters = signature.parameters - if parameters.get("view_kwargs"): - view_object = view_class(view_kwargs=view_kwargs) + try: + # handle async request + if view_class: + signature = inspect.signature(view_class.__init__) + parameters = signature.parameters + if parameters.get("view_kwargs"): + view_object = view_class(view_kwargs=view_kwargs) + else: + view_object = view_class() + response = await func(view_object, **func_kwargs) + else: + response = await func(**func_kwargs) + return response + except TypeError as e: + if error: + # Create a response with validation error details + validation_error_callback = getattr(current_app, "validation_error_callback") + abort(validation_error_callback(error)) else: - view_object = view_class() - response = await func(view_object, **func_kwargs) - else: - response = await func(**func_kwargs) - return response + raise e else: @wraps(func) - def view_func(**kwargs) -> FlaskResponse: - func_kwargs = _validate_request( + def view_func(**kwargs) -> Any | None: + func_kwargs, error = _validate_request( header=header, cookie=cookie, path=path, @@ -103,19 +112,26 @@ def view_func(**kwargs) -> FlaskResponse: raw=raw, path_kwargs=kwargs ) - - # handle request - if view_class: - signature = inspect.signature(view_class.__init__) - parameters = signature.parameters - if parameters.get("view_kwargs"): - view_object = view_class(view_kwargs=view_kwargs) + try: + # handle request + if view_class: + signature = inspect.signature(view_class.__init__) + parameters = signature.parameters + if parameters.get("view_kwargs"): + view_object = view_class(view_kwargs=view_kwargs) + else: + view_object = view_class() + response = func(view_object, **func_kwargs) + else: + response = func(**func_kwargs) + return response + except TypeError as e: + if error: + # Create a response with validation error details + validation_error_callback = getattr(current_app, "validation_error_callback") + abort(validation_error_callback(error)) else: - view_object = view_class() - response = func(view_object, **func_kwargs) - else: - response = func(**func_kwargs) - return response + raise e if not hasattr(func, "view"): func.view = view_func