Skip to content

Commit

Permalink
Fix DRF request and response types
Browse files Browse the repository at this point in the history
  • Loading branch information
itssimon committed Aug 28, 2023
1 parent ec0421a commit b151e58
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
7 changes: 3 additions & 4 deletions apitally/django_rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@


if TYPE_CHECKING:
from django.http import HttpRequest
from rest_framework.request import Request
from rest_framework.views import APIView


__all__ = ["ApitallyMiddleware", "HasAPIKey", "KeyInfo"]


class HasAPIKey(BasePermission): # type: ignore[misc]
def has_permission(self, request: HttpRequest, view: APIView) -> bool:
def has_permission(self, request: Request, view: APIView) -> bool:
authorization = request.headers.get("Authorization")
if not authorization:
return False
Expand All @@ -30,6 +30,5 @@ def has_permission(self, request: HttpRequest, view: APIView) -> bool:
return False
if hasattr(view, "required_scopes") and not key_info.check_scopes(view.required_scopes):
return False
if not hasattr(request, "key_info"):
setattr(request, "key_info", key_info)
request.auth = key_info
return True
17 changes: 9 additions & 8 deletions tests/django_rest_framework_urls.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.http import HttpRequest, HttpResponse
from django.urls import path
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

from apitally.django_rest_framework import HasAPIKey
Expand All @@ -8,30 +9,30 @@
class FooView(APIView):
permission_classes = [HasAPIKey]

def get(self, request: HttpRequest) -> HttpResponse:
return HttpResponse("foo")
def get(self, request: Request) -> Response:
return Response("foo")


class FooBarView(APIView):
permission_classes = [HasAPIKey]
required_scopes = ["foo"]

def get(self, request: HttpRequest, bar: int) -> HttpResponse:
return HttpResponse(f"foo: {bar}")
def get(self, request: Request, bar: int) -> Response:
return Response(f"foo: {bar}")


class BarView(APIView):
permission_classes = [HasAPIKey]
required_scopes = ["bar"]

def post(self, request: HttpRequest) -> HttpResponse:
return HttpResponse("bar")
def post(self, request: Request) -> Response:
return Response("bar")


class BazView(APIView):
permission_classes = [HasAPIKey]

def put(self, request: HttpRequest) -> HttpResponse:
def put(self, request: Request) -> Response:
raise ValueError("baz")


Expand Down

0 comments on commit b151e58

Please sign in to comment.