diff --git a/django_ratelimit/core.py b/django_ratelimit/core.py index 1270799..6ae1e85 100644 --- a/django_ratelimit/core.py +++ b/django_ratelimit/core.py @@ -1,6 +1,6 @@ -import ipaddress import functools import hashlib +import ipaddress import re import socket import time @@ -10,10 +10,8 @@ from django.core.cache import caches from django.core.exceptions import ImproperlyConfigured from django.utils.module_loading import import_string - from django_ratelimit import ALL, UNSAFE - __all__ = ['is_ratelimited', 'get_usage'] _PERIODS = { @@ -109,12 +107,12 @@ def _split_rate(rate): return count, seconds -def _get_window(value, period): +def _get_window(value, period, timestamp=None): """ Given a value, and time period return when the end of the current time period for rate evaluation is. """ - ts = int(time.time()) + ts = timestamp or int(time.time()) if period == 1: return ts if not isinstance(value, bytes): @@ -159,6 +157,20 @@ def is_ratelimited(request, group=None, fn=None, key=None, rate=None, def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, increment=False): + """ + Call get_usage_extended and strip out usage_details for backwards compatibility. + """ + usage = get_usage_extended(request, group, fn, key, rate, method, increment) + if usage is None: + return None + return { + result_key: usage[result_key] for result_key in usage + if result_key != 'usage_details' + } + + +def get_usage_extended(request, group=None, fn=None, key=None, rate=None, method=ALL, + increment=False): if group is None and fn is None: raise ImproperlyConfigured('get_usage must be called with either ' '`group` or `fn` arguments') @@ -219,7 +231,8 @@ def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, raise ImproperlyConfigured( 'Could not understand ratelimit key: %s' % key) - window = _get_window(value, period) + timestamp = int(time.time()) + window = _get_window(value, period, timestamp) initial_value = 1 if increment else 0 cache_name = getattr(settings, 'RATELIMIT_USE_CACHE', 'default') @@ -245,6 +258,19 @@ def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, else: count = cache.get(cache_key, initial_value) + # Collect the usage details for logging + usage_details = { + 'rate': rate, + 'period': period, + 'group': group, + 'key': key, + 'value': value, + 'timestamp': timestamp, + 'window': window, + 'cache_key': cache_key, + 'added': added, + } + # Getting or setting the count from the cache failed if count is None or count is False: if getattr(settings, 'RATELIMIT_FAIL_OPEN', False): @@ -254,6 +280,7 @@ def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, 'limit': 0, 'should_limit': True, 'time_left': -1, + 'usage_details': usage_details, } time_left = window - int(time.time()) @@ -262,6 +289,7 @@ def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, 'limit': limit, 'should_limit': count > limit, 'time_left': time_left, + 'usage_details': usage_details, } diff --git a/django_ratelimit/exceptions.py b/django_ratelimit/exceptions.py index f39a0f4..bb6101c 100644 --- a/django_ratelimit/exceptions.py +++ b/django_ratelimit/exceptions.py @@ -1,5 +1,19 @@ +import json + from django.core.exceptions import PermissionDenied +from django.core.serializers.json import DjangoJSONEncoder class Ratelimited(PermissionDenied): - pass + + def __init__(self, *args, usage=None, **kwargs): + self.usage = usage + super().__init__(*args, **kwargs) + # If python >=3.11 (has add_note), then add jsonified self.usage as note + if hasattr(self, "add_note"): + self.add_note("Usage: " + json.dumps( + self.usage, + indent=2, + cls=DjangoJSONEncoder, + default=lambda obj: str(obj), + )) diff --git a/docs/usage.rst b/docs/usage.rst index d9c0c99..2db748b 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -211,8 +211,8 @@ Core Methods In some cases the decorator is not flexible enough to, e.g., conditionally apply rate limits. In these cases, you can access the core -functionality in ``ratelimit.core``. The two major methods are -``get_usage`` and ``is_ratelimited``. +functionality in ``ratelimit.core``. The three major methods are +``get_usage``, ``get_usage_extended``, and ``is_ratelimited``. .. code-block:: python @@ -260,6 +260,15 @@ functionality in ``ratelimit.core``. The two major methods are the current count, limit, time left in the window, and whether this request should be limited. +.. py:function:: get_usage_extended(request, group=None, fn=None, key=None, \ + rate=None, method=ALL, increment=False) + + :returns dict or None: + The same as ``get_usage`` with additional information about the + rate limit, including the rate, period, group, key, value, timestamp, + window, cache_key, and if the cache was added. + + .. py:function:: is_ratelimited(request, group=None, fn=None, \ key=None, rate=None, method=ALL, \ increment=False) @@ -351,3 +360,47 @@ To use it, add ``django_ratelimit.middleware.RatelimitMiddleware`` to your The view specified in ``RATELIMIT_VIEW`` will get two arguments, the ``request`` object (after ratelimit processing) and the exception. + + +Extension Example +========== + +There are cases when you might want to extend the functionality of the +ratelimit decorator. For example, you might want to add more information +to the exception message. Here is an example of how to do that: + +.. code-block:: python + + from functools import wraps + + from django.conf import settings + from django.utils.module_loading import import_string + from django_ratelimit.core import get_usage_extended + from django_ratelimit.decorators import ALL + from django_ratelimit.exceptions import Ratelimited + + def ratelimit(group=None, key=None, rate=None, method=ALL, block=True): + def decorator(fn): + @wraps(fn) + def _wrapped(request, *args, **kw): + old_limited = getattr(request, 'limited', False) + usage = get_usage_extended( + request=request, + group=group, + fn=fn, + key=key, + rate=rate, + method=method, + increment=True + ) + ratelimited = usage['should_limit'] if usage else False + request.limited = ratelimited or old_limited + if ratelimited and block: + cls = getattr( + settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) + raise (import_string(cls) if isinstance(cls, str) else cls)(usage=usage) + return fn(request, *args, **kw) + + return _wrapped + + return decorator