From 74e96240ce7668bb94e47ba383f1e6745c5de37a Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Mon, 11 Nov 2019 17:08:35 +1100 Subject: [PATCH 1/2] Allow functions for get_schema_kwargs --- flask_rest_jsonapi/resource.py | 37 +++++++++++++++++++--- tests/test_sqlalchemy_data_layer.py | 49 ++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/flask_rest_jsonapi/resource.py b/flask_rest_jsonapi/resource.py index 34ceccf..59fd8ee 100644 --- a/flask_rest_jsonapi/resource.py +++ b/flask_rest_jsonapi/resource.py @@ -59,6 +59,31 @@ def __new__(cls): return super(Resource, cls).__new__(cls) + def _access_kwargs(self, name, args, kwargs): + """ + Gets the kwargs dictionary with the provided name. This can be implemented as + a dictionary *or* a function, so we have to handle both possibilities + """ + # Access the field + val = getattr(self, name, dict()) + + if callable(val): + # If it's a function, call it and validate its result + schema_kwargs = val(args, kwargs) + if not isinstance(schema_kwargs, dict): + raise TypeError( + 'The return value of the "{}" function must be a dictionary of kwargs' + ) + else: + # If it's a dictionary, use it directly + schema_kwargs = val + if not isinstance(schema_kwargs, dict): + raise TypeError( + 'The value of the "{}" class variable must be a dictionary of kwargs' + ) + + return schema_kwargs + @jsonapi_exception_formatter def dispatch_request(self, *args, **kwargs): """Logic of how to handle a request""" @@ -118,7 +143,9 @@ def get(self, *args, **kwargs): objects_count, objects = self.get_collection(qs, kwargs) - schema_kwargs = getattr(self, 'get_schema_kwargs', dict()) + # get_schema_kwargs can be a class variable or a function + schema_kwargs = self._access_kwargs('get_schema_kwargs', args, kwargs) + schema_kwargs.update() schema_kwargs.update({'many': True}) self.before_marshmallow(args, kwargs) @@ -149,8 +176,9 @@ def post(self, *args, **kwargs): qs = QSManager(request.args, self.schema) + schema_kwargs = self._access_kwargs('post_schema_kwargs', args, kwargs) schema = compute_schema(self.schema, - getattr(self, 'post_schema_kwargs', dict()), + schema_kwargs, qs, qs.include) @@ -230,8 +258,9 @@ def get(self, *args, **kwargs): self.before_marshmallow(args, kwargs) + schema_kwargs = self._access_kwargs('get_schema_kwargs', args, kwargs) schema = compute_schema(self.schema, - getattr(self, 'get_schema_kwargs', dict()), + schema_kwargs, qs, qs.include) @@ -247,7 +276,7 @@ def patch(self, *args, **kwargs): json_data = request.get_json() or {} qs = QSManager(request.args, self.schema) - schema_kwargs = getattr(self, 'patch_schema_kwargs', dict()) + schema_kwargs = self._access_kwargs('patch_schema_kwargs', args, kwargs) schema_kwargs.update({'partial': True}) self.before_marshmallow(args, kwargs) diff --git a/tests/test_sqlalchemy_data_layer.py b/tests/test_sqlalchemy_data_layer.py index ea40923..62f868c 100644 --- a/tests/test_sqlalchemy_data_layer.py +++ b/tests/test_sqlalchemy_data_layer.py @@ -6,7 +6,7 @@ from sqlalchemy import create_engine, Column, Integer, DateTime, String, ForeignKey from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.ext.declarative import declarative_base -from flask import Blueprint, make_response, json +from flask import Blueprint, make_response, json, Flask from marshmallow_jsonapi.flask import Schema, Relationship from marshmallow import Schema as MarshmallowSchema from marshmallow_jsonapi import fields @@ -457,6 +457,7 @@ def register_routes(client, app, api_blueprint, person_list, person_detail, pers api.route(string_json_attribute_person_detail, 'string_json_attribute_person_detail', '/string_json_attribute_persons/') api.init_app(app) + return api @pytest.fixture(scope="module") @@ -647,6 +648,52 @@ def test_get_list_disable_pagination(client, register_routes): response = client.get('/persons' + '?' + querystring, content_type='application/vnd.api+json') assert response.status_code == 200 +def test_get_list_class_kwargs(session, person, person_schema, person_model, computer_list): + class PersonDetail(ResourceDetail): + schema = person_schema + data_layer = { + 'model': person_model, + 'session': session, + 'url_field': 'person_id' + } + + get_schema_kwargs = dict( + exclude=['name'] + ) + + app = Flask('test') + api = Api(app=app) + api.route(PersonDetail, 'api.person_detail', '/persons/') + api.route(computer_list, 'api.computer_list', '/computers', '/persons//computers') + api.init_app(app) + + ret = app.test_client().get('/persons/{}'.format(person.person_id)) + + assert 'name' not in ret.json['data']['attributes'] + +def test_get_list_func_kwargs(session, person, person_schema, person_model, computer_list): + class PersonDetail(ResourceDetail): + schema = person_schema + data_layer = { + 'model': person_model, + 'session': session, + 'url_field': 'person_id' + } + + def get_schema_kwargs(self, args, kwargs): + return dict( + exclude=['name'] + ) + + app = Flask('test') + api = Api(app=app) + api.route(PersonDetail, 'api.person_detail', '/persons/') + api.route(computer_list, 'api.computer_list', '/computers', '/persons//computers') + api.init_app(app) + + ret = app.test_client().get('/persons/{}'.format(person.person_id)) + + assert 'name' not in ret.json['data']['attributes'] def test_head_list(client, register_routes): with client: From 53886869811fce7cc1713f42b7142dfec148048e Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Mon, 11 Nov 2019 17:09:34 +1100 Subject: [PATCH 2/2] Comments for tests --- tests/test_sqlalchemy_data_layer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_sqlalchemy_data_layer.py b/tests/test_sqlalchemy_data_layer.py index 62f868c..b3b98f8 100644 --- a/tests/test_sqlalchemy_data_layer.py +++ b/tests/test_sqlalchemy_data_layer.py @@ -649,6 +649,9 @@ def test_get_list_disable_pagination(client, register_routes): assert response.status_code == 200 def test_get_list_class_kwargs(session, person, person_schema, person_model, computer_list): + """ + Test a resource that defines its get_schema_kwargs as a dictionary class variable + """ class PersonDetail(ResourceDetail): schema = person_schema data_layer = { @@ -672,6 +675,9 @@ class PersonDetail(ResourceDetail): assert 'name' not in ret.json['data']['attributes'] def test_get_list_func_kwargs(session, person, person_schema, person_model, computer_list): + """ + Test a resource that defines its get_schema_kwargs as a function + """ class PersonDetail(ResourceDetail): schema = person_schema data_layer = {