drf-限流组件的使用与源码分析

258 阅读4分钟

drf-限流组件的使用与源码分析

频率限制在认证、权限之后

本质:将存储当前客户访问的时间戳到一个列表中,通过列表的长度和对列表的首个元素对比来进行频率的限制

  • 匿名用户,用IP作为用户唯一标记,但如果用户换代理IP,无法做到真正的限制。
  • 登录用户,用用户名或用户ID做标识。
限制:60s能访问3次
来访问时:
    1.获取当前时间 100121280
    2.100121280-60 = 100121220,小于100121220所有记录删除
    3.判断1分钟以内已经访问多少次了? 4 
    4.无法访问
停一会
来访问时:
    1.获取当前时间 100121340
    2.100121340-60 = 100121280,小于100121280所有记录删除
    3.判断1分钟以内已经访问多少次了? 0
    4.可以访问

使用方式

依赖缓存,所以得先配置好redis

缓存={
    用户标识:[12:33,12:32,12:31,12:30,12,]    1小时/5次   12:34   11:34
}
pip3 install django-redis
# settings.py
CACHES = {
    "default": {
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": "redis://127.0.0.1:6379",
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
            "PASSWORD": "qwe123",
        }
    }
}

使用自定义的限流类

自定义限流类

from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache
​
​
class MyRateThrottle(SimpleRateThrottle):
    cache = default_cache  # 访问记录存放在django的缓存中(需设置缓存)
    scope = "user"  # 构造缓存中的key
    cache_format = 'throttle_%(scope)s_%(ident)s'
​
    # 设置访问频率,例如:1分钟允许访问10次
    # 其他:'s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day'
    THROTTLE_RATES = {"user": "10/m"}
​
    def get_cache_key(self, request, view):
        if request.user:
            ident = request.user.pk  # 用户ID
        else:
            ident = self.get_ident(request)  # 获取请求用户IP(去request中找请求头)
​
        # throttle_user_11.11.11.11
        return self.cache_format % {'scope': self.scope, 'ident': ident}

视图类

class LoginView(APIView):
    authentication_classes = []  # 设置要应用的认真类
    permission_classes = []
    throttle_classes = [MyRateThrottle, ]
​
    def post(self, request, *args, **kwargs):
        username = request.data.get('username')
        password = request.data.get('password')
        user_object = models.UserInfo.objects.filter(username=username, password=password)
        if not user_object:
            return Response({'status': False, 'msg': '登录失败'})
        token = str(uuid.uuid4())
        user_object.update(token=token)
        return Response({'status': True, 'data': token})

settings设置全局限流次数

REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": ["ext.throttle.MyRateThrottle"],
    "DEFAULT_THROTTLE_RATES": {
        "user": '3/m'
    }
}

使用官方自带的限流类

在要使用频率限制的CBV里添上throttle_classes = [AnonRateThrottle,]

from rest_framework.views import APIView
from rest_framework.response import Response
​
from rest_framework.throttling import AnonRateThrottle
​
class ArticleView(APIView):
    # 频率限制
    throttle_classes = [AnonRateThrottle,]
    
    def get(self,request,*args,**kwargs):
        return Response('ArticleView')

频率限制setting配置

这样是限制每分钟3次

REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"],
    "DEFAULT_THROTTLE_RATES":{
    	"anon":'3/m'
    }
}

源码简略分析

image-20220913170802344

image-20220913171542709

源码逐步分析

第一步:限流类对象的加载

①、请求进来是执行CBV的dispatch方法,这次依旧只需要关注initial方法

image-20220913095013273

②、进入initial方法,倒数第一行的限流组件,check_throttles

image-20220913144029688

③、进入check_throttles方法,重点要了解self.get_throttles这个方法

def check_throttles(self, request):

    throttle_durations = []
    for throttle in self.get_throttles():
        if not throttle.allow_request(request, self):
            throttle_durations.append(throttle.wait())

    if throttle_durations:
        durations = [
            duration for duration in throttle_durations
            if duration is not None
        ]

        duration = max(durations, default=None)
        self.throttled(request, duration)

④、进入self.get_throttles,与认证、权限组件类似,就是获取限流类的实例化对象列表

def get_throttles(self):

    return [throttle() for throttle in self.throttle_classes]

⑤、实例化对象就会执行__init__方法,自定义的类中没有,就会找到SimpleRateThrottle这个类的__init__方法

def __init__(self):
    if not getattr(self, 'rate', None):
    	self.rate = self.get_rate()
    self.num_requests, self.duration = self.parse_rate(self.rate)

最开始没有rate,所以一定为空,实例化就一定去调用self.get_rate()方法去给rate进行赋值image-20220913145845055

⑥、进到self.get_rate()方法,可以看到

def get_rate(self):
    if not getattr(self, 'scope', None):
        msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
               self.__class__.__name__)
        raise ImproperlyConfigured(msg)

    try:
        return self.THROTTLE_RATES[self.scope]
    except KeyError:
        msg = "No default throttle rate set for '%s' scope" % self.scope
        raise ImproperlyConfigured(msg)

image-20220913145414080

⑦、在进入到self.parse_rate

可以看出将传入的rate进行字符串切割,前面的值转换成int类型, 后面的根据首字母转换成对应的秒数,即多少秒内访问多少次

def parse_rate(self, rate):
    if rate is None:
    return (None, None)
    num, period = rate.split('/')
    num_requests = int(num)
    duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
    return (num_requests, duration)

总结:初始化对象时会获取THROTTLE_RATES的值(限制的配置)来获取到 访问次数+时间间隔 (num_requests, duration)

第二步:allow_request是否限流

image-20220913154018119

①回到上一步的③中的check_throttles方法,分析throttle.allow_request(request, self):,因为一般的自定义限流类或者用官方自带的限流类都是继承SimpleRateThrottle这个类,所以直接分析这个类中的allow_request方法

from rest_framework.throttling import SimpleRateThrottle
def allow_request(self, request, view):
    if self.rate is None:
        return True

    self.key = self.get_cache_key(request, view)
    if self.key is None:
        return True

    self.history = self.cache.get(self.key, [])
    self.now = self.timer()

    while self.history and self.history[-1] <= self.now - self.duration:
        self.history.pop()
    if len(self.history) >= self.num_requests:
        return self.throttle_failure()
    return self.throttle_success()

def throttle_success(self):
    self.history.insert(0, self.now)
    self.cache.set(self.key, self.history, self.duration)
    return True

def throttle_failure(self):
    return False

image-20220913153852972

②、再次回到check_throttles方法,此时可以根据allow_request返回的值来进行是否允许进去视图函数还是直接报错

image-20220913152326129

③、其中的wait()函数源码是这样通过计算到剩余的时间的

def wait(self):
    
    if self.history:
        # 这里就直接能计算出等待的时间
        remaining_duration = self.duration - (self.now - self.history[-1])
    else:
        remaining_duration = self.duration
	
    # 在调用的时候一般这个值就为1,所以返回的值就是等待时间/1就是等待时间
    available_requests = self.num_requests - len(self.history) + 1
    if available_requests <= 0:
        return None

    return remaining_duration / float(available_requests)

④、这个错误提示事调用了一个exceptions.Throttled(wait),所以我们也可以通过重写这个throttled方法来定制错误的信息

def throttled(self, request, wait):

	raise exceptions.Throttled(wait)

根据源码来自定义错误信息

本质重写throttled方法以及Throttled类

完整代码:

from rest_framework import exceptions
from rest_framework import status
from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache


class ThrottledException(exceptions.APIException):
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_code = 'throttled'

    
class MyRateThrottle(SimpleRateThrottle):
    scope = "user"  
    THROTTLE_RATES = {"user": "10/m"}

    def get_cache_key(self, request, view):
        if request.user:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)  
        return self.cache_format % {'scope': self.scope, 'ident': ident}

    def throttle_failure(self):
        wait = self.wait()
        detail = {
            "code": 1005,
            "data": "访问频率限制",
            'detail': "需等待{}s才能访问".format(int(wait))
        }
        raise ThrottledException(detail)

\