DRF 源码解析 限流组件(四)

164 阅读8分钟

概述

限流机制可以不让用户访问某一接口过于频繁,例如短信服务。

限制访问频率的前提是找到访问者的唯一标识。对于已登录用户,可以用用户信息的主键、ID或用户名作为唯一标识。对于未登录用户,一般用IP作为唯一标识(容易被代理IP绕过),必要时可以再配合前端JS算法加以辅助。

限制方法(以10分钟访问3次的限制为例):

  1. 维护一个记录访问时间的列表,列表名是用户的唯一标识
  2. 收到请求后,将当前时间记录进列表
  3. 删除列表中距当前时间超过十分钟的访问记录
  4. 计算列表长度,超过3次则触发限流,报错;未超过则允许访问

快速使用

DRF的基础限流类

from rest_framework.throttling import BaseThrottle

class BaseThrottle:
    """Base class for throttles: subclasses decide whether a request may proceed."""

    # Core hook of every throttle class.
    # Returning True from allow_request() means the request was not throttled.
    def allow_request(self, request, view):
        # Return True if the request should be allowed, False otherwise.
        raise NotImplementedError('.allow_request() must be overridden')

    # Build the identifier used to tell clients apart.
    # In this base class the identifier is derived from the client's IP address.
    def get_ident(self, request):
        # Identify the client from HTTP_X_FORWARDED_FOR when it is present and
        # NUM_PROXIES is greater than 0. When NUM_PROXIES is None, use all of
        # HTTP_X_FORWARDED_FOR if it is available, otherwise REMOTE_ADDR.
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            # Count back `num_proxies` entries from the end of the header,
            # clamped to the number of addresses actually present.
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        # NUM_PROXIES is None: the header with all whitespace removed, if any.
        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        # Consulted after the request was throttled, i.e. allow_request()
        # returned False.
        # May optionally return the number of seconds to wait.
        return None

一个更好用的在基础限流类之上扩展的限流类:SimpleRateThrottle

from rest_framework.throttling import SimpleRateThrottle

class SimpleRateThrottle(BaseThrottle):
    # A simple cache-backed throttle; subclasses only need to override the
    # `.get_cache_key()` method.
    """
    The rate (requests / seconds) is set by a `rate` attribute on the Throttle
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # A `rate` attribute set on the class wins; otherwise look it up by scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    # Produce the unique identifier of the client being throttled.
    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # The rate is looked up in THROTTLE_RATES using the scope as key.
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # Only the first character of the period selects the duration.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    # Still the core of the throttle.
    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        # (throttle_success inserts new timestamps at index 0, so the oldest
        # entry is the last element.)
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        # `self.duration` is passed as the cache timeout for the entry.
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            # Time left until the oldest recorded request leaves the window.
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

在此基础之上,构建自己的限流类:

from rest_framework.throttling import BaseThrottle, SimpleRateThrottle
from django.core.cache import cache as default_cache


class MyThrottle(SimpleRateThrottle):
    """Allow five requests per minute per client.

    Authenticated users are identified by their primary key; all other
    requests are identified by the value of ``BaseThrottle.get_ident``.
    """

    scope = "my_throttle"
    THROTTLE_RATES = {"my_throttle": "5/m"}  # rate limit: five requests per minute
    cache = default_cache  # cache backend configured under CACHES in settings.py

    def get_cache_key(self, request, view):
        """Return the cache key under which this client's history is stored.

        The key is built from ``cache_format`` using this throttle's scope and
        the client's identifier. With a Redis cache backend it corresponds to
        the key of the access history in Redis.
        """
        user = request.user
        # By default an unauthenticated request still carries a truthy
        # AnonymousUser whose pk is None, so testing `request.user` alone
        # would put every anonymous client into one shared bucket.
        if user and user.is_authenticated:
            # Logged-in users are identified by their primary key.
            ident = user.pk
        else:
            # Everyone else is identified through BaseThrottle.get_ident().
            ident = self.get_ident(request)

        return self.cache_format % {'scope': self.scope, 'ident': ident}

其中default_cache要在settings.py中设置,这里我们以redis作为我们的缓存:

CACHES = {
    "default": {
        # Requires the django-redis package to be installed.
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": "redis://127.0.0.1:6379/0",
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
            "PASSWORD": "<PASSWORD>",
        },
    },
}

再将我们创建的限流类应用到视图中:

class LoginView(APIView):
    # Throttle classes applied to this view.
    throttle_classes = [MyThrottle]

对象的创建与加载

限流类的实例化:

class APIView(View):
    # Global (settings-level) configuration of the throttle classes.
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
    ...
    def dispatch(self, request, *args, **kwargs):
        # `.dispatch()` is pretty much the same as Django's regular dispatch,
        # but with extra hooks for startup, finalize, and exception handling.
        self.args = args
        self.kwargs = kwargs
        # Wrap the incoming request.
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            # Authentication, permission and throttle checks all run in here.
            self.initial(request, *args, **kwargs)
            ...
        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

    def initial(self, request, *args, **kwargs):
        ...
        # Ensure that the incoming request is permitted
        self.perform_authentication(request)  # authentication
        self.check_permissions(request)  # permissions
        self.check_throttles(request)  # throttling

    def check_throttles(self, request):
        # Check if request should be throttled.
        # Raises an appropriate exception if the request is throttled.
        throttle_durations = []
        # get_throttles() returns the list of throttle instances.
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            # Report the longest wait; None when no throttle gave a duration.
            duration = max(durations, default=None)
            self.throttled(request, duration)

    def get_throttles(self):
        # Instantiate every throttle class listed in throttle_classes.
        return [throttle() for throttle in self.throttle_classes]
        
        
class SimpleRateThrottle(BaseThrottle):
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # A `rate` attribute set on the class wins; otherwise look it up by scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_rate(self):
        # Determine the string representation of the allowed request rate.
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # MyThrottle defines THROTTLE_RATES = {"my_throttle": "5/m"},
            # so this is a plain dictionary lookup, keyed by scope, that
            # yields the rate limit of this throttle class.
            # The rates can also be defined globally in settings.py under
            # DEFAULT_THROTTLE_RATES.
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        # Given the request rate string, return a two tuple of:
        # <allowed number of requests>, <period of time in seconds>
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # period[0] takes only the first character of the period string, so
        # "10/min" or "33/hour" still resolve to the correct key.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

请求进来,先执行dispatch方法,然后执行initial方法,最后执行check_throttles方法。在check_throttles方法中,通过get_throttles方法获取由视图类throttle_classes属性中每一个限流类的实例组成的列表。

初始化限流类的时候,如果类上没有直接定义rate属性,会首先执行get_rate方法。它以scope为键,从THROTTLE_RATES这个属性中获取该限流类的访问频率限制。THROTTLE_RATES默认取自settings.py中REST_FRAMEWORK配置的DEFAULT_THROTTLE_RATES,因此也可以在那里按scope定义:

REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': (
        'rest_framework.authentication.SessionAuthentication',
    ),
    'DEFAULT_PERMISSION_CLASSES': (
        'rest_framework.permissions.IsAuthenticated',
    ),
    # Keys are the `scope` values of the throttle classes, not their class
    # names: the rate is looked up as THROTTLE_RATES[self.scope].
    'DEFAULT_THROTTLE_RATES': {
        "my_throttle": "5/m",
        "another_throttle": "10/m",
    }
}

get_rate方法获取到的访问频率字符串会在__init__方法中交给parse_rate方法解析,结果赋值给self.num_requests和self.duration两个属性。num_requests是允许的访问次数,duration是时间窗口的长度,以秒计算。

综上所述,获取每个限流类对象的过程,就是每个限流类获取自己的rate属性,即该限流类的时间间隔以及访问次数的过程。

限流类的执行

APIView视图类的check_throttles方法中,循环遍历每个限流类的实例,并执行每个实例的allow_request方法。如果返回True,则该限流类放行,继续检查下一个限流类;若返回False,则从该限流类的wait方法中获取需等待的时间,并加入列表throttle_durations中。

class SimpleRateThrottle(BaseThrottle):
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def allow_request(self, request, view):
        # Implement the check to see if the request should be throttled.
        # On success calls `throttle_success`.
        # On failure calls `throttle_failure`.

        if self.rate is None:
            return True

        # Obtain the unique identifier of the client.
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # Load the access history recorded for this identifier.
        self.history = self.cache.get(self.key, [])
        # Current timestamp; self.timer is time.time.
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        # Entries older than (now - duration) are too old and get removed,
        # so everything left lies inside the throttle window.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # The length of the history is the number of requests in the window.
        if len(self.history) >= self.num_requests:
            # The request limit has been reached: throttled.
            return self.throttle_failure()

        # The request limit has not been reached: not throttled.
        return self.throttle_success()

    def throttle_success(self):
        # Add this request to the access history and store it in the cache.
        self.history.insert(0, self.now)
        # The entry is stored with `self.duration` as its cache timeout.
        self.cache.set(self.key, self.history, self.duration)
        # Return True: not throttled.
        return True

    def throttle_failure(self):
        # Throttled: return False.
        return False

    def wait(self):
        # Return how many seconds until the next request would be allowed.
        if self.history:
            # Time left until the oldest recorded request leaves the window.
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

如果allow_request返回False,那么wait方法会返回需要等待的时间;不同限流类的等待时间都存在throttle_durations列表中,check_throttles方法会选出其中最大的等待时间,并通过throttled方法抛出Throttled异常(HTTP 429)。

class APIView(View):
    def throttled(self, request, wait):
        # If request is throttled, determine what kind of exception to raise.
        raise exceptions.Throttled(wait)

    def dispatch(self, request, *args, **kwargs):
        ...
        try:
            self.initial(request, *args, **kwargs)
            ...
        # An exception raised because of throttling is caught here.
        except Exception as exc:
            response = self.handle_exception(exc)
        ...
        
class Throttled(APIException):
    # Exception raised for a throttled request; carries the 429 status code.
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_detail = _('Request was throttled.')
    extra_detail_singular = _('Expected available in {wait} second.')
    extra_detail_plural = _('Expected available in {wait} seconds.')
    default_code = 'throttled'

    def __init__(self, wait=None, detail=None, code=None):
        if detail is None:
            detail = force_str(self.default_detail)
        if wait is not None:
            # Round the wait up to whole seconds, then append the singular or
            # plural "Expected available in ..." sentence to the detail text.
            wait = math.ceil(wait)
            detail = ' '.join((
                detail,
                force_str(ngettext(self.extra_detail_singular.format(wait=wait),
                                   self.extra_detail_plural.format(wait=wait),
                                   wait))))
        # Stored as None when no wait time was given.
        self.wait = wait
        super().__init__(detail, code)

用户登录简单应用

一个未登录用户访问的登录接口,对IP进行限制:

class IPThrottle(SimpleRateThrottle):
    """Throttle that identifies callers through ``get_ident`` (their IP address)."""

    scope = "ip_throttle"
    cache = default_cache  # cache backend configured in settings.py

    def get_cache_key(self, request, view):
        """Build the cache key from the caller's ident and this throttle's scope."""
        # Anonymous users are identified by their IP address.
        key_parts = {'ident': self.get_ident(request), 'scope': self.scope}
        return self.cache_format % key_parts

class UserThrottle(SimpleRateThrottle):
    """Throttle keyed by user id, falling back to ``get_ident`` when anonymous."""

    scope = "user_throttle"
    cache = default_cache

    def get_cache_key(self, request, view):
        """Return the cache key for the requesting user."""
        user = request.user
        if user and user.is_authenticated:
            # Logged-in users are identified by their user id.
            ident = user.pk
        else:
            # Without this fallback every anonymous request would use
            # ident None and share a single bucket (or fail outright when
            # request.user is None).
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}


class LoginView(APIView):
    throttle_classes = [IPThrottle]


class UserVIew(APIView):
    throttle_classes = [UserThrottle]

别忘记在settings.py中设置DEFAULT_THROTTLE_RATES:

REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": (
        "rest_framework.authentication.SessionAuthentication",
    ),
    "DEFAULT_PERMISSION_CLASSES": (
        "rest_framework.permissions.IsAuthenticated",
    ),
    # Rates keyed by the `scope` of each throttle class.
    "DEFAULT_THROTTLE_RATES": {
        "ip_throttle": "20/m",
        "user_throttle": "10/m",
    },
}