diff --git a/README.rst b/README.rst index 77391e7..d60e686 100644 --- a/README.rst +++ b/README.rst @@ -64,10 +64,10 @@ Quickstart def get(self, pet_id): return Pet.query.filter(Pet.id == pet_id).one() - @use_kwargs(PetSchema) + @use_args(PetSchema) @marshal_with(PetSchema, code=201) - def post(self, **kwargs): - return Pet(**kwargs) + def post(self, data): + return Pet(**data) @use_kwargs(PetSchema) @marshal_with(PetSchema) diff --git a/docs/usage.rst b/docs/usage.rst index bce0f54..ffe52db 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -6,20 +6,26 @@ Usage Decorators ---------- -Use the :func:`use_kwargs ` and :func:`marshal_with ` decorators on functions, methods, or classes to declare request parsing and response marshalling behavior, respectively. +Use the :func:`use_args `, :func:`use_kwargs ` and :func:`marshal_with ` decorators on functions, methods, or classes to declare request parsing and response marshalling behavior, respectively. .. code-block:: python import flask from webargs import fields - from flask_apispec import use_kwargs, marshal_with + from flask_apispec import use_args, use_kwargs, marshal_with from .models import Pet from .schemas import PetSchema app = flask.Flask(__name__) - @app.route('/pets') + @app.route('/pets', methods=['POST']) + @use_args(PetSchema) + @marshal_with(PetSchema) + def create_pet(data): + return Pet(**data) + + @app.route('/pets', methods=['GET']) @use_kwargs({'species': fields.Str()}) @marshal_with(PetSchema(many=True)) def list_pets(**kwargs): diff --git a/flask_apispec/__init__.py b/flask_apispec/__init__.py index 38e478a..423e07b 100644 --- a/flask_apispec/__init__.py +++ b/flask_apispec/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from flask_apispec.views import ResourceMeta, MethodResource -from flask_apispec.annotations import doc, wrap_with, use_kwargs, marshal_with +from flask_apispec.annotations import doc, wrap_with, use_args, use_kwargs, marshal_with from flask_apispec.extension import FlaskApiSpec from flask_apispec.utils import Ref @@ -8,6 +8,7 @@ __all__ = [ 'doc', 'wrap_with', + 'use_args', 'use_kwargs', 'marshal_with', 'ResourceMeta', diff --git a/flask_apispec/annotations.py b/flask_apispec/annotations.py index 6e1992c..c677595 100644 --- a/flask_apispec/annotations.py +++ b/flask_apispec/annotations.py @@ -5,7 +5,45 @@ from flask_apispec import utils from flask_apispec.wrapper import Wrapper -def use_kwargs(args, locations=None, inherit=None, apply=None, **kwargs): + +def use_args(argmap, locations=None, inherit=None, apply=None, **kwargs): + """Inject positional arguments from the specified webargs arguments into the + decorated view function. + + Usage: + + .. code-block:: python + + from marshmallow import fields, Schema + + class PetSchema(Schema): + name = fields.Str() + + @use_args(PetSchema) + def create_pet(data): + pet = Pet(**data) + return session.add(pet) + + :param argmap: Mapping of argument names to :class:`Field ` + objects, :class:`Schema `, or a callable which accepts a + request and returns a :class:`Schema ` + :param locations: Default request locations to parse + :param inherit: Inherit args from parent classes + :param apply: Parse request with specified args + """ + kwargs.update({'locations': locations}) + + def wrapper(func): + options = { + 'argmap': argmap, + 'kwargs': kwargs + } + annotate(func, 'args', [options], inherit=inherit, apply=apply) + return activate(func) + + return wrapper + +def use_kwargs(argmap, locations=None, inherit=None, apply=None, **kwargs): """Inject keyword arguments from the specified webargs arguments into the decorated view function. @@ -19,7 +57,7 @@ def use_kwargs(args, locations=None, inherit=None, apply=None, **kwargs): def get_pets(**kwargs): return Pet.query.filter_by(**kwargs).all() - :param args: Mapping of argument names to :class:`Field ` + :param argmap: Mapping of argument names to :class:`Field ` objects, :class:`Schema `, or a callable which accepts a request and returns a :class:`Schema ` :param locations: Default request locations to parse @@ -30,14 +68,13 @@ def get_pets(**kwargs): def wrapper(func): options = { - 'args': args, + 'argmap': argmap, 'kwargs': kwargs, } - annotate(func, 'args', [options], inherit=inherit, apply=apply) + annotate(func, 'kwargs', [options], inherit=inherit, apply=apply) return activate(func) return wrapper - def marshal_with(schema, code='default', description='', inherit=None, apply=None): """Marshal the return value of the decorated view function using the specified schema. diff --git a/flask_apispec/apidoc.py b/flask_apispec/apidoc.py index de539cd..5178ecc 100644 --- a/flask_apispec/apidoc.py +++ b/flask_apispec/apidoc.py @@ -68,9 +68,9 @@ def get_parent(self, view): def get_parameters(self, rule, view, docs, parent=None): openapi = self.marshmallow_plugin.openapi - annotation = resolve_annotations(view, 'args', parent) + annotation = resolve_annotations(view, 'kwargs', parent) args = merge_recursive(annotation.options) - schema = args.get('args', {}) + schema = args.get('argmap', {}) if is_instance_or_subclass(schema, Schema): converter = openapi.schema2parameters elif callable(schema): diff --git a/flask_apispec/views.py b/flask_apispec/views.py index 06b4fd8..2bc5903 100644 --- a/flask_apispec/views.py +++ b/flask_apispec/views.py @@ -7,7 +7,7 @@ def inherit(child, parents): child.__apispec__ = child.__dict__.get('__apispec__', {}) - for key in ['args', 'schemas', 'docs']: + for key in ['kwargs', 'schemas', 'docs']: child.__apispec__.setdefault(key, []).extend( annotation for parent in parents diff --git a/flask_apispec/wrapper.py b/flask_apispec/wrapper.py index f04e6d0..f514019 100644 --- a/flask_apispec/wrapper.py +++ b/flask_apispec/wrapper.py @@ -4,6 +4,7 @@ except ImportError: # Python 2 from collections import Mapping +from types import MethodType import flask import marshmallow as ma @@ -37,21 +38,32 @@ def __call__(self, *args, **kwargs): return self.marshal_result(unpacked, status_code) def call_view(self, *args, **kwargs): + view_fn = self.func config = flask.current_app.config parser = config.get('APISPEC_WEBARGS_PARSER', flaskparser.parser) + # Delegate webargs.use_args annotations annotation = utils.resolve_annotations(self.func, 'args', self.instance) if annotation.apply is not False: for option in annotation.options: - schema = utils.resolve_schema(option['args'], request=flask.request) - parsed = parser.parse(schema, locations=option['kwargs']['locations']) + schema = utils.resolve_schema(option['argmap'], request=flask.request) + view_fn = parser.use_args(schema, **option['kwargs'])(view_fn) + # Delegate webargs.use_kwargs annotations + annotation = utils.resolve_annotations(self.func, 'kwargs', self.instance) + if annotation.apply is not False: + for option in annotation.options: + schema = utils.resolve_schema(option['argmap'], request=flask.request) if getattr(schema, 'many', False): - args += tuple(parsed) - elif isinstance(parsed, Mapping): - kwargs.update(parsed) - else: - args += (parsed, ) - - return self.func(*args, **kwargs) + raise Exception("@use_kwargs cannot be used with a with a " + "'many=True' schema, as it must deserialize " + "to a dict") + elif isinstance(schema, ma.Schema): + # Spy the post_load to provide a more informative error + # if it doesn't return a Mapping + post_load_fns = post_load_fn_names(schema) + for post_load_fn_name in post_load_fns: + spy_post_load(schema, post_load_fn_name) + view_fn = parser.use_kwargs(schema, **option['kwargs'])(view_fn) + return view_fn(*args, **kwargs) def marshal_result(self, unpacked, status_code): config = flask.current_app.config @@ -78,3 +90,46 @@ def format_output(values): while values[-1] is None: values = values[:-1] return values if len(values) > 1 else values[0] + +def post_load_fn_names(schema): + fn_names = [] + if hasattr(schema, '_hooks'): + # Marshmallow >=3 + hooks = getattr(schema, '_hooks') + for key in ((ma.decorators.POST_LOAD, True), + (ma.decorators.POST_LOAD, False)): + if key in hooks: + fn_names.append(*hooks[key]) + else: + # Marshmallow <= 2 + processors = getattr(schema, '__processors__') + for key in ((ma.decorators.POST_LOAD, True), + (ma.decorators.POST_LOAD, False)): + if key in processors: + fn_names.append(*processors[key]) + return fn_names + +def spy_post_load(schema, post_load_fn_name): + processor = getattr(schema, post_load_fn_name) + + def _spy_processor(_self, *args, **kwargs): + rv = processor(*args, **kwargs) + if not isinstance(rv, Mapping): + raise Exception("The @use_kwargs decorator can only use Schemas that " + "return dicts, but the @post_load-annotated method " + "'{schema_type}.{post_load_fn_name}' returned: {rv}" + .format(schema_type=type(schema), + post_load_fn_name=post_load_fn_name, + rv=rv)) + return rv + + for attr in ( + # Marshmallow <= 2.x + '__marshmallow_tags__', + '__marshmallow_kwargs__', + # Marshmallow >= 3.x + '__marshmallow_hook__' + ): + if hasattr(processor, attr): + setattr(_spy_processor, attr, getattr(processor, attr)) + setattr(schema, post_load_fn_name, MethodType(_spy_processor, schema)) diff --git a/tests/test_views.py b/tests/test_views.py index e7323fb..66232b2 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -4,13 +4,96 @@ from flask import make_response from marshmallow import fields, Schema, post_load +import pytest from flask_apispec.utils import Ref from flask_apispec.views import MethodResource -from flask_apispec import doc, use_kwargs, marshal_with +from flask_apispec import doc, use_args, use_kwargs, marshal_with class TestFunctionViews: + def test_use_args(self, app, client): + @app.route('/') + @use_args({'name': fields.Str()}) + def view(*args): + return args + res = client.get('/', {'name': 'freddie'}) + assert res.json == {'name': 'freddie'} + + def test_use_args_schema(self, app, client): + class ArgSchema(Schema): + name = fields.Str() + + @app.route('/') + @use_args(ArgSchema) + def view(*args): + return args + res = client.get('/', {'name': 'freddie'}) + assert res.json == {'name': 'freddie'} + + def test_use_args_schema_with_post_load(self, app, client): + class User: + def __init__(self, name): + self.name = name + + def update(self, name): + self.name = name + + class ArgSchema(Schema): + name = fields.Str() + + @post_load + def make_object(self, data, **kwargs): + return User(**data) + + @app.route('/', methods=('POST', )) + @use_args(ArgSchema()) + def view(user): + assert isinstance(user, User) + return {'name': user.name} + + data = {'name': 'freddie'} + res = client.post('/', data) + assert res.json == data + + def test_use_args_schema_many(self, app, client): + class ArgSchema(Schema): + name = fields.Str() + + @app.route('/', methods=('POST',)) + @use_args(ArgSchema(many=True), locations=('json',)) + def view(*args): + return args + data = [{'name': 'freddie'}, {'name': 'john'}] + res = client.post('/', json.dumps(data), content_type='application/json') + assert res.json == data + + def test_use_args_multiple(self, app, client): + @app.route('/') + @use_args({'name': fields.Str()}) + @use_args({'instrument': fields.Str()}) + def view(*args): + return list(args) + res = client.get('/', {'name': 'freddie', 'instrument': 'vocals'}) + assert res.json == [{'instrument': 'vocals'}, {'name': 'freddie'}] + + def test_use_args_callable_as_schema(self, app, client): + def schema_factory(request): + assert request.method == 'GET' + assert request.path == '/' + + class ArgSchema(Schema): + name = fields.Str() + + return ArgSchema + + @app.route('/') + @use_args(schema_factory) + def view(*args): + return args + res = client.get('/', {'name': 'freddie'}) + assert res.json == {'name': 'freddie'} + def test_use_kwargs(self, app, client): @app.route('/') @use_kwargs({'name': fields.Str()}) @@ -43,11 +126,12 @@ class ArgSchema(Schema): @post_load def make_object(self, data, **kwargs): - return User(**data) + return {"user": User(**data)} @app.route('/', methods=('POST', )) @use_kwargs(ArgSchema()) - def view(user): + def view(**kwargs): + user = kwargs["user"] assert isinstance(user, User) return {'name': user.name} @@ -55,6 +139,32 @@ def view(user): res = client.post('/', data) assert res.json == data + def test_use_kwargs_schema_with_post_load_schema(self, app, client): + class User: + def __init__(self, name): + self.name = name + + def update(self, name): + self.name = name + + class ArgSchema(Schema): + name = fields.Str() + + @post_load + def make_object(self, data, **kwargs): + return User(**data) + + @app.route('/', methods=('POST', )) + @use_kwargs(ArgSchema()) + def view(user): + assert isinstance(user, User) + return {'name': user.name} + + data = {'name': 'freddie'} + with pytest.raises(Exception, match=r"The @use_kwargs decorator can only use Schemas that return " + r"dicts, but the @post_load-annotated method '.*' returned: .*"): + client.post('/', data) + def test_use_kwargs_schema_many(self, app, client): class ArgSchema(Schema): name = fields.Str() @@ -64,8 +174,10 @@ class ArgSchema(Schema): def view(*args): return list(args) data = [{'name': 'freddie'}, {'name': 'john'}] - res = client.post('/', json.dumps(data), content_type='application/json') - assert res.json == data + + with pytest.raises(Exception, match="@use_kwargs cannot be used with a with a 'many=True' " + "schema, as it must deserialize to a dict"): + client.post('/', json.dumps(data), content_type='application/json', expect_errors=True) def test_use_kwargs_multiple(self, app, client): @app.route('/')