drf限流组件(throttle)应用及源码分析

208 阅读 · 阅读时长约 2 分钟

1.限流组件概述

限流,即限制用户的访问频率,例如:用户 1 分钟内最多访问 100 次,或者短信验证码每天最多发送 50 次,以防止盗刷。限流时需要用一个唯一标识来区分访问者:

  • 已登录用户:使用用户信息中的主键(ID)或用户名作为唯一标识
  • 未登录用户:使用 IP 地址作为唯一标识

2.限流组件应用

2.1 编写限流类

from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache    # 连接redis


class IpThrottle(SimpleRateThrottle):
    """Throttle requests per client IP address.

    The rate is looked up under the ``"ip"`` scope in
    ``DEFAULT_THROTTLE_RATES``; request history lives in Django's default
    cache.
    """

    scope = "ip"
    cache = default_cache

    def get_cache_key(self, request, view):
        # get_ident() is inherited from DRF's BaseThrottle and derives the
        # client's address from the request.
        client_ip = self.get_ident(request)
        return self.cache_format % dict(scope=self.scope, ident=client_ip)


class UserThrottle(SimpleRateThrottle):
    """Throttle requests per user (scope ``"user"``).

    The rate is looked up under the ``"user"`` scope in
    ``DEFAULT_THROTTLE_RATES``; request history lives in Django's default
    cache. Authenticated users are keyed by primary key; anonymous requests
    fall back to the client IP so they are still limited individually.
    """

    scope = "user"
    cache = default_cache

    def get_cache_key(self, request, view):
        if request.user and request.user.is_authenticated:
            ident = request.user.pk  # user ID
        else:
            # An anonymous user has no primary key (pk is None); keying on it
            # would put every unauthenticated client into one shared
            # "throttle_user_None" bucket, so use the client IP instead.
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

2.2 安装 django-redis(pip install django-redis)并在 settings.py 中进行配置

2.2.1 redis相关配置

# django-redis backend for Django's default cache. Throttle classes that use
# django.core.cache.cache keep their request history in this Redis instance.
CACHES = dict(
    default=dict(
        BACKEND="django_redis.cache.RedisCache",
        LOCATION="redis://127.0.0.1:6379",
        OPTIONS=dict(
            CLIENT_CLASS="django_redis.client.DefaultClient",
            PASSWORD="密码",  # placeholder: replace with the real Redis password
        ),
    ),
)

2.2.2 限流频率相关配置

# Throttle rates keyed by each throttle class's ``scope``. Format is
# "<count>/<period>"; only the period's first letter (s/m/h/d) is used.
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_RATES=dict(
        ip="10/m",   # scope "ip": 10 requests per minute per client IP
        user="5/m",  # scope "user": 5 requests per minute per user
    ),
)

2.3 启动redis服务

启动 Redis 服务(例如执行 redis-server),并确认 settings 中配置的地址和密码可以正常连接。

2.4 局部应用

import uuid

from rest_framework.response import Response
from rest_framework.views import APIView

from ext.throttle import IpThrottle, MyThrottle

class LoginView(APIView):
    """Username/password login that issues a fresh random token.

    Login happens before the caller is authenticated, so the throttle keys on
    the client IP rather than on a user ID.
    """

    throttle_classes = [IpThrottle, ]

    def post(self, request):
        # 1. Read the username and password submitted in the POST body.
        user = request.data.get("username")
        pwd = request.data.get("password")

        # 2. Verify the credentials against the database.
        # NOTE(review): this matches a plaintext password column; consider
        # storing hashes and using Django's check_password instead.
        # NOTE(review): `models` is assumed to be the app's models module
        # (e.g. `from api import models`) - confirm the import.
        user_object = models.UserInfo.objects.filter(username=user, password=pwd).first()
        if not user_object:
            return Response({"status": False, 'msg': "用户名或密码错误"})

        # 3. Credentials are valid: issue a new token and persist it.
        token = str(uuid.uuid4())
        user_object.token = token
        user_object.save()

        return Response({"status": True, 'data': token})

3.限流组件源码分析

class IpThrottle(SimpleRateThrottle):    # custom throttle class
    """Per-IP throttle (scope ``"ip"``) backed by Django's default cache."""

    # Requires these module-level imports:
    #   from rest_framework.throttling import SimpleRateThrottle
    #   from django.core.cache import cache as default_cache
    # Importing inside the class body cannot supply the base class (it is
    # resolved before the body runs) and would leak both names as class
    # attributes.

    scope = "ip"
    cache = default_cache

    def get_cache_key(self, request, view):
        # get_ident() is inherited from BaseThrottle and derives the client's
        # address from the request.
        ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}


class UserThrottle(SimpleRateThrottle):    # custom throttle class
    """Per-user throttle (scope ``"user"``) backed by Django's default cache.

    Authenticated users are keyed by primary key; anonymous requests fall
    back to the client IP so they are still limited individually.
    """

    # Requires these module-level imports:
    #   from rest_framework.throttling import SimpleRateThrottle
    #   from django.core.cache import cache as default_cache
    # Importing inside the class body cannot supply the base class (it is
    # resolved before the body runs) and would leak both names as class
    # attributes.

    scope = "user"
    cache = default_cache

    def get_cache_key(self, request, view):
        if request.user and request.user.is_authenticated:
            ident = request.user.pk  # user ID
        else:
            # An anonymous user has no primary key (pk is None); keying on it
            # would put every unauthenticated client into one shared bucket.
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

class SimpleRateThrottle(BaseThrottle):
    """Cache-backed sliding-window throttle (excerpt of DRF's source).

    Subclasses implement ``get_cache_key``. The rate comes from a ``rate``
    attribute if set, otherwise from ``THROTTLE_RATES[scope]``. Each cache
    entry holds the timestamps of recent requests for one key, newest first.
    """

    cache = default_cache
    timer = time.time
    # Cache key template; subclasses fill in scope and ident.
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # An explicit ``rate`` attribute takes precedence over the scope lookup.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """Return the cache key for this request, or None to skip throttling.

        Must be overridden by subclasses.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """Return the rate string configured for ``self.scope``.

        Raises:
            ImproperlyConfigured: if ``scope`` is unset, or has no entry in
                ``THROTTLE_RATES``.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # Core step 5: fetch the rate string for this scope, e.g. "10/m".
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """Split a rate such as "10/m" into (num_requests, duration_seconds).

        Returns (None, None) when ``rate`` is None.
        """
        # Core step 6: parse the rate string into the allowed request count
        # and the window length converted to seconds (s/m/h/d).
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # Only the period's first character is used, so spellings such as
        # "hour" or "day" work too.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """Return True if the request may proceed, False if it is throttled.

        Requests pass unthrottled when no rate is configured or when
        ``get_cache_key`` returns None.
        """
        if self.rate is None:
            return True
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
        # Timestamps of earlier requests for this key, newest first.
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Core step 8: the throttling itself. Drop timestamps that have left
        # the window (the oldest sit at the end of the list)...
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # ...then reject if the window already holds num_requests entries.
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """Record the current request at the front of the history and allow it."""
        self.history.insert(0, self.now)
        # Third argument is the cache timeout: the entry lives for one window.
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """Called when the request is rejected; always returns False."""
        return False

    def wait(self):
        """Return the suggested number of seconds to wait before retrying.

        Returns None when the computed number of available requests is not
        positive (possible if the history holds more entries than
        ``num_requests``).
        """
        if self.history:
            # Time until the oldest recorded request leaves the window.
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None
        return remaining_duration / float(available_requests)

class APIView(View):
    """Excerpt of DRF's APIView showing only the throttling path.

    Call order: dispatch -> initial -> check_throttles -> get_throttles ->
    throttle.allow_request (-> throttle.wait on rejection).
    """
    # Default throttles from settings; a view can override this, e.g.
    # ``throttle_classes = [IpThrottle, ]``.
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
    
    
    def get_throttles(self):
        """Instantiate one throttle object per class in ``throttle_classes``."""
        # Core step 4
        return [throttle() for throttle in self.throttle_classes]
    
    def check_throttles(self, request):
        """Consult every throttle and call ``throttled`` if any rejects.

        All throttles run (no short-circuit); the longest non-None ``wait()``
        among the rejecting ones is passed on.
        """
        throttle_durations = []
        
        # Core step 3: build the throttle objects via get_throttles().
        for throttle in self.get_throttles():
            # Core step 7: run allow_request(request, self); True means this
            # throttle lets the request through.
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # wait() may return None; ignore those when choosing the maximum.
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]
            duration = max(durations, default=None)
            # NOTE(review): throttled() is not shown in this excerpt;
            # presumably it raises so the handler never runs - confirm
            # against DRF's source.
            self.throttled(request, duration)
    
    def initial(self, request, *args, **kwargs):
        """Run per-request checks before the handler method is dispatched."""
        self.format_kwarg = self.get_format_suffix(**kwargs)
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme
        self.perform_authentication(request)
        self.check_permissions(request)
        
        # Core step 2: throttling runs after authentication and permission
        # checks, so throttles can key on request.user.
        self.check_throttles(request)
    
    def dispatch(self, request, *args, **kwargs):
        """Entry point for every request.

        Exceptions raised in ``initial`` (including from throttling) or in
        the handler are converted to a response by ``handle_exception``.
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?
        try:
            
            # Core step 1
            self.initial(request, *args, **kwargs)

            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed
            response = handler(request, *args, **kwargs)
        except Exception as exc:
            response = self.handle_exception(exc)
        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response