-
Notifications
You must be signed in to change notification settings - Fork 191
Added support and test for async #300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c72d16d
aef86d2
7ae7c9f
41bf6a8
dda2c56
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ | |
from django_ratelimit import ALL, UNSAFE | ||
|
||
|
||
__all__ = ['is_ratelimited', 'get_usage'] | ||
__all__ = ['is_ratelimited', 'ais_ratelimited', 'get_usage', 'aget_usage'] | ||
|
||
_PERIODS = { | ||
's': 1, | ||
|
@@ -156,9 +156,30 @@ def is_ratelimited(request, group=None, fn=None, key=None, rate=None, | |
|
||
return usage['should_limit'] | ||
|
||
async def ais_ratelimited(request, group=None, fn=None, key=None, rate=None, | ||
method=ALL, increment=False): | ||
usage = await aget_usage(request, group, fn, key, rate, method, increment) | ||
if usage is None: | ||
return False | ||
|
||
return usage['should_limit'] | ||
|
||
|
||
def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, | ||
increment=False): | ||
usage = _get_usage(request, group, fn, key, rate, method, increment) | ||
if usage is not None: | ||
return usage() | ||
|
||
async def aget_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, | ||
increment=False): | ||
usage = _get_usage(request, group, fn, key, rate, method, increment, is_async=True) | ||
if usage is not None: | ||
return await usage() | ||
|
||
|
||
def _get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, | ||
increment=False, is_async=False): | ||
if group is None and fn is None: | ||
raise ImproperlyConfigured('get_usage must be called with either ' | ||
'`group` or `fn` arguments') | ||
|
@@ -227,45 +248,102 @@ def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, | |
cache_key = _make_cache_key(group, window, rate, value, method) | ||
|
||
count = None | ||
try: | ||
added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE) | ||
except socket.gaierror: # for redis | ||
added = False | ||
if added: | ||
count = initial_value | ||
if is_async: | ||
async def inner(): | ||
try: | ||
# Some caches don't have an async implementation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've tried it before, and if someone is using an old cache this can fail. This is just to support caches that may not define the cache interface with those methods. Test implementations can also be missing this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you clarify what you mean by "old cache"? Is it an implementation that doesn't inherit from |
||
if hasattr(cache, 'aadd'): | ||
added = await cache.aadd(cache_key, initial_value, period + EXPIRATION_FUDGE) | ||
else: | ||
added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE) | ||
except socket.gaierror: # for redis | ||
added = False | ||
if added: | ||
count = initial_value | ||
else: | ||
if increment: | ||
try: | ||
# python3-memcached will throw a ValueError if the server is | ||
# unavailable or (somehow) the key doesn't exist. redis, on the | ||
# other hand, simply returns None. | ||
if hasattr(cache, 'aincr'): | ||
count = await cache.aincr(cache_key) | ||
else: | ||
count = cache.incr(cache_key) | ||
except ValueError: | ||
pass | ||
else: | ||
if hasattr(cache, 'aget'): | ||
count = await cache.aget(cache_key, initial_value) | ||
else: | ||
count = cache.get(cache_key, initial_value) | ||
|
||
# Getting or setting the count from the cache failed | ||
if count is None or count is False: | ||
if getattr(settings, 'RATELIMIT_FAIL_OPEN', False): | ||
return None | ||
return { | ||
'count': 0, | ||
'limit': 0, | ||
'should_limit': True, | ||
'time_left': -1, | ||
} | ||
|
||
time_left = window - int(time.time()) | ||
return { | ||
'count': count, | ||
'limit': limit, | ||
'should_limit': count > limit, | ||
'time_left': time_left, | ||
} | ||
else: | ||
if increment: | ||
def inner(): | ||
try: | ||
# python3-memcached will throw a ValueError if the server is | ||
# unavailable or (somehow) the key doesn't exist. redis, on the | ||
# other hand, simply returns None. | ||
count = cache.incr(cache_key) | ||
except ValueError: | ||
pass | ||
else: | ||
count = cache.get(cache_key, initial_value) | ||
|
||
# Getting or setting the count from the cache failed | ||
if count is None or count is False: | ||
if getattr(settings, 'RATELIMIT_FAIL_OPEN', False): | ||
return None | ||
return { | ||
'count': 0, | ||
'limit': 0, | ||
'should_limit': True, | ||
'time_left': -1, | ||
} | ||
|
||
time_left = window - int(time.time()) | ||
return { | ||
'count': count, | ||
'limit': limit, | ||
'should_limit': count > limit, | ||
'time_left': time_left, | ||
} | ||
added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE) | ||
except socket.gaierror: # for redis | ||
added = False | ||
if added: | ||
count = initial_value | ||
else: | ||
if increment: | ||
try: | ||
# python3-memcached will throw a ValueError if the server is | ||
# unavailable or (somehow) the key doesn't exist. redis, on the | ||
# other hand, simply returns None. | ||
count = cache.incr(cache_key) | ||
except ValueError: | ||
pass | ||
else: | ||
count = cache.get(cache_key, initial_value) | ||
|
||
# Getting or setting the count from the cache failed | ||
if count is None or count is False: | ||
if getattr(settings, 'RATELIMIT_FAIL_OPEN', False): | ||
return None | ||
return { | ||
'count': 0, | ||
'limit': 0, | ||
'should_limit': True, | ||
'time_left': -1, | ||
} | ||
|
||
time_left = window - int(time.time()) | ||
return { | ||
'count': count, | ||
'limit': limit, | ||
'should_limit': count > limit, | ||
'time_left': time_left, | ||
} | ||
|
||
return inner | ||
|
||
|
||
|
||
is_ratelimited.ALL = ALL | ||
is_ratelimited.UNSAFE = UNSAFE | ||
ais_ratelimited.ALL = ALL | ||
ais_ratelimited.UNSAFE = UNSAFE | ||
get_usage.ALL = ALL | ||
get_usage.UNSAFE = UNSAFE | ||
aget_usage.ALL = ALL | ||
aget_usage.UNSAFE = UNSAFE |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,29 +5,44 @@ | |
|
||
from django_ratelimit import ALL, UNSAFE | ||
from django_ratelimit.exceptions import Ratelimited | ||
from django_ratelimit.core import is_ratelimited | ||
from django_ratelimit.core import is_ratelimited, ais_ratelimited | ||
from asgiref.sync import iscoroutinefunction | ||
|
||
|
||
__all__ = ['ratelimit'] | ||
|
||
|
||
def ratelimit(group=None, key=None, rate=None, method=ALL, block=True): | ||
def decorator(fn): | ||
@wraps(fn) | ||
def _wrapped(request, *args, **kw): | ||
old_limited = getattr(request, 'limited', False) | ||
ratelimited = is_ratelimited(request=request, group=group, fn=fn, | ||
key=key, rate=rate, method=method, | ||
increment=True) | ||
request.limited = ratelimited or old_limited | ||
if ratelimited and block: | ||
cls = getattr( | ||
settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) | ||
raise (import_string(cls) if isinstance(cls, str) else cls)() | ||
return fn(request, *args, **kw) | ||
if iscoroutinefunction(fn): | ||
Comment on lines
15
to
+17
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not something to change here, but I wonder if having a second There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Async is barely complete even in Django 5.0, so I doubt anybody would care if 3.2 doesn't support the async decorator. If you're doing async work, you're probably running the latest version of Django. |
||
@wraps(fn) | ||
async def _wrapped(request, *args, **kw): | ||
old_limited = getattr(request, 'limited', False) | ||
ratelimited = await ais_ratelimited(request=request, group=group, fn=fn, | ||
key=key, rate=rate, method=method, | ||
increment=True) | ||
request.limited = ratelimited or old_limited | ||
if ratelimited and block: | ||
cls = getattr( | ||
settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) | ||
raise (import_string(cls) if isinstance(cls, str) else cls)() | ||
return await fn(request, *args, **kw) | ||
else: | ||
@wraps(fn) | ||
def _wrapped(request, *args, **kw): | ||
old_limited = getattr(request, 'limited', False) | ||
ratelimited = is_ratelimited(request=request, group=group, fn=fn, | ||
key=key, rate=rate, method=method, | ||
increment=True) | ||
request.limited = ratelimited or old_limited | ||
if ratelimited and block: | ||
cls = getattr( | ||
settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) | ||
raise (import_string(cls) if isinstance(cls, str) else cls)() | ||
return fn(request, *args, **kw) | ||
return _wrapped | ||
return decorator | ||
|
||
return decorator | ||
|
||
ratelimit.ALL = ALL | ||
ratelimit.UNSAFE = UNSAFE |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
|
||
def my_ip(req): | ||
return req.META['MY_THING'] | ||
|
||
def callable_rate(group, request): | ||
if request.user.is_authenticated: | ||
return None | ||
return (0, 1) | ||
|
||
def mykey(group, request): | ||
return request.META['REMOTE_ADDR'][::-1] | ||
|
||
class CustomRatelimitedException(Exception): | ||
pass |
Uh oh!
There was an error while loading. Please reload this page.