DRF-限流组件的基本使用及源码剖析

67 阅读 · 阅读时长约 5 分钟

限流组件的基本使用及源码剖析

使用方法有两种:自己从头编写限流类,或者继承 DRF 提供的限流类。DRF 限流类的基类 BaseThrottle 如下:

class BaseThrottle:
    """Base class for throttles (quoted from DRF).

    Subclasses must implement ``allow_request``. ``get_ident`` derives a
    client identifier from the request's META headers, and ``wait`` may
    report how many seconds to wait before the next request.
    """
    def allow_request(self, request, view):
        """Return True if the request should be allowed; must be overridden."""
        raise NotImplementedError('.allow_request() must be overridden')
    def get_ident(self, request):
        """Return an identifier for the client making the request.

        Uses ``HTTP_X_FORWARDED_FOR`` and ``REMOTE_ADDR`` from
        ``request.META`` together with the ``NUM_PROXIES`` setting.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            # NUM_PROXIES is configured: with zero proxies, or when there is
            # no X-Forwarded-For header, use REMOTE_ADDR directly.
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            # Take the address num_proxies entries from the end of the
            # X-Forwarded-For list, clamped to the list's length.
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        # NUM_PROXIES not configured: use the whole X-Forwarded-For header
        # with all whitespace removed, falling back to REMOTE_ADDR.
        return ''.join(xff.split()) if xff else remote_addr
    def wait(self):
        """Return seconds to wait before the next request (None here)."""
        return None

继承 DRF 的限流类来编写

后面将在继承SimpleRateThrottle的基础上编写自定义的限流类:

class SimpleRateThrottle(BaseThrottle):
    """Cache-backed throttle allowing ``num_requests`` per ``duration`` seconds.

    The rate string comes from ``self.rate`` if set, otherwise from
    ``THROTTLE_RATES[self.scope]``. Subclasses must implement
    ``get_cache_key``. Each key's request history is kept in ``cache`` as a
    list of timestamps, newest first.
    """
    # Cache backend that stores each client's request history.
    cache = default_cache
    # Clock used for request timestamps (seconds).
    timer = time.time
    # Template for cache keys; subclasses' get_cache_key fill in 'scope' and
    # 'ident'.
    cache_format = 'throttle_%(scope)s_%(ident)s'
    # Name used to look up this throttle's rate in THROTTLE_RATES.
    scope = None
    # Rate table keyed by scope, taken from the DEFAULT_THROTTLE_RATES setting.
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # An explicit `rate` attribute wins; otherwise resolve it via scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """Return the cache key for this request, or None to skip throttling.

        Must be overridden by subclasses.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """Return the rate string for ``self.scope`` from THROTTLE_RATES.

        Raises:
            ImproperlyConfigured: if ``scope`` is unset or has no entry.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """Parse ``'<count>/<period>'`` into ``(num_requests, duration_seconds)``.

        Only the first character of ``period`` is used ('s', 'm', 'h' or
        'd'). Returns ``(None, None)`` when ``rate`` is None.
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """Return True if the request is within the rate limit.

        Side effects: sets ``self.key``, ``self.history`` and ``self.now``,
        which ``throttle_success`` and ``wait`` read afterwards.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # History is newest-first, so expired entries sit at the end.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """Record the current request at the front of the history; return True."""
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """Called when the request is rejected; returns False."""
        return False

    def wait(self):
        """Return the suggested wait in seconds before the next request.

        Time until the oldest recorded request leaves the window, divided by
        ``num_requests - len(history) + 1``; None if that divisor is <= 0
        (i.e. the history is longer than the current limit).
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

编写自己的限流类:

class MyThrottle(SimpleRateThrottle):
    """Throttle requests to 5 per minute per user (or per client address).

    ``scope`` and ``THROTTLE_RATES`` must be *class* attributes:
    ``SimpleRateThrottle.__init__`` calls ``get_rate()``, which reads
    ``self.scope`` and ``self.THROTTLE_RATES``. Locals with the same names
    inside ``get_cache_key`` have no effect; ``scope`` would stay ``None``
    and instantiating the throttle would raise ``ImproperlyConfigured``.
    """
    scope = 'xxx'
    # Per-class rate table; it shadows the one read from settings. Remove
    # it to take the rate from REST_FRAMEWORK['DEFAULT_THROTTLE_RATES'].
    THROTTLE_RATES = {'xxx': '5/m'}

    def get_cache_key(self, request, view):
        """Return the cache key identifying the throttled client.

        Authenticated users are keyed by primary key. Anonymous requests
        fall back to the client address from ``get_ident``. Django's
        ``AnonymousUser`` is truthy, so ``if request.user`` cannot tell the
        two apart; its ``id`` is None, so all anonymous clients would
        otherwise share a single bucket.
        """
        user = getattr(request, 'user', None)
        if user is not None and user.is_authenticated:
            ident = user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

限流类的使用

class login(APIView):
    """Example view protected by ``MyThrottle``.

    DRF's ``finalize_response`` expects the handler to return a
    ``Response``/``HttpResponse``; a bare ``return`` (None) is rejected
    there, so the handler returns a ``Response``.
    """
    throttle_classes = [MyThrottle,]

    def post(self, request, *args, **kwargs):
        """Handle POST; reached only if every throttle allowed the request."""
        # In real code, import this at module top.
        from rest_framework.response import Response
        return Response({'detail': 'ok'})

限流组件源码剖析

接收到请求后,dispatch 会调用 self.initial:

   def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.

        Throttling runs inside ``self.initial``. Because that call sits in
        the ``try`` block, an exception raised by a throttle check is turned
        into a response by ``self.handle_exception`` and the method handler
        is never called.
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            # Authentication, permission and throttle checks happen here.
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            # Any exception from initial() or the handler becomes a response.
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

  def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.

        The final checks run in a fixed order: authentication, then
        permissions, then throttling.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        self.perform_authentication(request)
        self.check_permissions(request)
        # Throttling is checked last, after authentication and permissions.
        self.check_throttles(request)

initial 中的 self.check_throttles(request) 会遍历所有限流类的实例,并依次执行每个实例的 allow_request 方法:

    def get_throttles(self):
        """Instantiate and return the throttles in ``self.throttle_classes``.

        ``throttle_classes`` holds either the view's own throttle classes or
        the globally configured ones.
        """
        return [throttle() for throttle in self.throttle_classes]

    def check_throttles(self, request):
        """Consult every throttle and call ``self.throttled`` if any rejects.

        For each throttle that rejects the request, its ``wait()`` value is
        collected. The longest non-None wait (or None) is then passed on.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)

通过循环将我们的限流类全部实例化之后,会去执行限流类中的allow_request方法:

我们自己写的MyThrottle类继承SimpleRateThrottle

class SimpleRateThrottle(BaseThrottle):
    """Excerpt of SimpleRateThrottle: the methods used when it is instantiated
    and when ``allow_request`` runs (other members elided with ``...``)."""
    ...
    def __init__(self):
        # An explicit `rate` attribute wins; otherwise resolve it via scope,
        # then split it into a request count and a window length.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_rate(self):
        """Return ``THROTTLE_RATES[self.scope]``.

        Raises:
            ImproperlyConfigured: if ``scope`` is unset or has no entry.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def allow_request(self, request, view):
        """Return True if the request is within the rate limit.

        Sets ``self.key``, ``self.history`` and ``self.now`` as side effects.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

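在 __init__ 方法中,通过 self.rate = self.get_rate() 给 rate 赋值。get_rate() 读取的是 self.scope 和 self.THROTTLE_RATES,属性查找会先在我们自己定义的 MyThrottle 类中进行:即从 THROTTLE_RATES 这个字典中取出 scope 这个键对应的值。因此 scope 和 THROTTLE_RATES 必须定义为类属性;定义在方法内部的同名局部变量不会生效: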

    # Meant as class attributes on the throttle class (not locals inside a
    # method), because get_rate() reads self.scope and self.THROTTLE_RATES.
    scope = 'xxx'
    THROTTLE_RATES = {'xxx':'5/m'}

如果我们有多个限流类,可以在每个限流类中定义不同的 scope,然后在 settings 的 DEFAULT_THROTTLE_RATES 中为每个 scope 设置对应的速率。此时限流类中不要再定义 THROTTLE_RATES,否则类属性会覆盖 settings 中的配置:

# Project settings: default rates per throttle scope, which
# SimpleRateThrottle reads as THROTTLE_RATES.
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_RATES': {
        'xxx': '5/m',
        'x1': '10/m',
    },
}
#self.rate获取到我们设置的速率后是一个字符串,会由self.parse_rate(self.rate)进行解析

class SimpleRateThrottle(BaseThrottle):
    def __init__(self):
        # Resolve the rate string (e.g. '5/m'), then split it into the
        # allowed request count and the window length in seconds.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def parse_rate(self, rate):
        """Parse ``'<count>/<period>'`` into ``(num_requests, duration_seconds)``.

        The string is split on '/'. Supported periods are 's', 'm', 'h' and
        'd'. Only the first character of the period is read
        (``period[0]``), so '5/min' or '5/hour' also work. Returns
        ``(None, None)`` when ``rate`` is None.
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

在循环实例化每一个限流类时,self.num_requests 和 self.duration 就分别得到了允许的访问次数和限制的时间长度(秒)。实例化完成后,就会执行每一个限流类的 allow_request():

#这是在上面的dispatch->initial->check_throttles
    def check_throttles(self, request):
        """Consult every throttle and raise if any of them rejects the request.

        Reached via dispatch -> initial -> check_throttles.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                # Rejected (False): record how long this throttle says to wait.
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]
            # Several throttles may reject at once, so report the longest wait.
            duration = max(durations, default=None)
            # Raises the throttling exception.
            self.throttled(request, duration)

            
#####################超出限制速率,抛出异常#####################
            
    def throttled(self, request, wait):
        """Reject the request by raising ``exceptions.Throttled(wait)``."""
        # Rate limit exceeded: raise instead of returning.
        raise exceptions.Throttled(wait)
        
class Throttled(APIException):
    """HTTP 429 error raised when a request is throttled.

    If ``wait`` is given, it is rounded up to whole seconds and appended to
    the detail message using the singular or plural wording.
    """
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_detail = _('Request was throttled.')
    extra_detail_singular = _('Expected available in {wait} second.')
    extra_detail_plural = _('Expected available in {wait} seconds.')
    default_code = 'throttled'

    def __init__(self, wait=None, detail=None, code=None):
        if detail is None:
            detail = force_str(self.default_detail)
        if wait is not None:
            # Round up to whole seconds.
            wait = math.ceil(wait)
            # ngettext picks the singular or plural text based on `wait`.
            detail = ' '.join((
                detail,
                force_str(ngettext(self.extra_detail_singular.format(wait=wait),
                                   self.extra_detail_plural.format(wait=wait),
                                   wait))))
        # Keep the (rounded) wait on the exception instance.
        self.wait = wait
        super().__init__(detail, code)
        
  ##########################################################    
            
    def allow_request(self, request, view):
        """Return True if the request is within the rate limit, else False.

        Sets ``self.key``, ``self.history`` and ``self.now``, which
        ``throttle_success`` and ``wait`` read afterwards.
        """
        if self.rate is None:
            return True

        # Get the unique identifier (cache key) for this client.
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # Fetch the request history: timestamps, newest first
        # (illustrated as clock times: [12:15, 12:14, 12:13]).
        self.history = self.cache.get(self.key, [])
        # Current timestamp.
        self.now = self.timer()

        # Pop entries older than the throttle window (now - duration).
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # Compare the remaining entries with the number allowed in the window.
        if len(self.history) >= self.num_requests:
            # Limit reached: returns False.
            return self.throttle_failure()
        return self.throttle_success()
    
    # On success, put this request's timestamp at the front of the history
    # and write it back to `self.cache` (Redis only if that is the
    # configured cache backend).
    def throttle_success(self):
        """Record the current request and allow it; returns True."""
        self.history.insert(0, self.now)
        # Timeout is `duration`, so an idle client's history expires.
        self.cache.set(self.key, self.history, self.duration)
        return True
    
    
#如果allow_request返回了False,就会调用self.wait()
    def wait(self):
        """Return the suggested wait in seconds after a rejected request.

        Computed as the time until the oldest recorded request leaves the
        window, divided by ``num_requests - len(history) + 1``. Returns None
        when that divisor is <= 0, i.e. the history is longer than the
        current limit.
        """
        if self.history:
            # Time left until the oldest entry falls out of the window.
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)