diff --git a/flask_rest_jsonapi/api.py b/flask_rest_jsonapi/api.py index a385256..5b6e459 100644 --- a/flask_rest_jsonapi/api.py +++ b/flask_rest_jsonapi/api.py @@ -66,27 +66,41 @@ def route(self, resource, view, *urls, **kwargs): :param list urls: the urls of the view :param dict kwargs: additional options of the route """ - resource.view = view url_rule_options = kwargs.get('url_rule_options') or dict() - view_func = resource.as_view(view) - + # Find the parent object for this route, and also the correct endpoint name if 'blueprint' in kwargs: - resource.view = '.'.join([kwargs['blueprint'].name, resource.view]) - for url in urls: - kwargs['blueprint'].add_url_rule(url, view_func=view_func, **url_rule_options) + view_name = '.'.join([kwargs['blueprint'].name, view]) + blueprint = kwargs['blueprint'] elif self.blueprint is not None: - resource.view = '.'.join([self.blueprint.name, resource.view]) - for url in urls: - self.blueprint.add_url_rule(url, view_func=view_func, **url_rule_options) + view_name = '.'.join([self.blueprint.name, view]) + blueprint = self.blueprint elif self.app is not None: + view_name = view + blueprint = self.app + else: + view_name = view + blueprint = None + + # Give the resource class a default endpoint. This will be overwritten if route() + # is called again for this resource + resource.view = view_name + + view_func = resource.as_view( + view, + endpoint=view_name, + ) + + if blueprint is not None: for url in urls: - self.app.add_url_rule(url, view_func=view_func, **url_rule_options) + blueprint.add_url_rule(url, view_func=view_func, **url_rule_options) else: - self.resources.append({'resource': resource, - 'view': view, - 'urls': urls, - 'url_rule_options': url_rule_options}) + self.resources.append({ + 'resource': resource, + 'view': view, + 'urls': urls, + 'url_rule_options': url_rule_options + }) self.resource_registry.append(resource) diff --git a/flask_rest_jsonapi/resource.py b/flask_rest_jsonapi/resource.py index 34ceccf..d940306 100644 --- a/flask_rest_jsonapi/resource.py +++ b/flask_rest_jsonapi/resource.py @@ -52,13 +52,23 @@ def __new__(cls, name, bases, d): class Resource(MethodView): """Base resource class""" - def __new__(cls): + def __new__(cls, **kwargs): """Constructor of a resource instance""" if hasattr(cls, '_data_layer'): cls._data_layer.resource = cls return super(Resource, cls).__new__(cls) + def __init__(self, endpoint=None): + # By default we assign each Resource class with a view/endpoint. However if the + # same resource is used for multiple routes, the endpoint will be overwritten. + # This ensures the Resource instances have the correct view + if endpoint is not None: + self.view = endpoint + + def parse_request(self): + return self.request_parsers[request.content_type](request) + @jsonapi_exception_formatter def dispatch_request(self, *args, **kwargs): """Logic of how to handle a request""" diff --git a/tests/test_sqlalchemy_data_layer.py b/tests/test_sqlalchemy_data_layer.py index ea40923..08e5cf0 100644 --- a/tests/test_sqlalchemy_data_layer.py +++ b/tests/test_sqlalchemy_data_layer.py @@ -1800,6 +1800,24 @@ def test_api_resources(app, person_list): api.route(person_list, 'person_list2', '/persons', '/person_list') api.init_app(app) +def test_api_resources_multiple_route(app, person_list): + """ + If we use the same resource twice, each instance of that resource should have the + correct endpoint + """ + api = Api() + + class DummyResource(ResourceDetail): + def get(self): + return self.view + + api.route(DummyResource, 'endpoint1', '/url1') + api.route(DummyResource, 'endpoint2', '/url2') + api.init_app(app) + + with app.test_client() as client: + assert client.get('/url1', content_type='application/vnd.api+json').json == 'endpoint1' + assert client.get('/url2', content_type='application/vnd.api+json').json == 'endpoint2' def test_relationship_containing_hyphens(client, register_routes, person_computers, computer_schema, person): response = client.get('/persons/{}/relationships/computers-owned'.format(person.person_id), content_type='application/vnd.api+json')