概述
限流机制可以不让用户访问某一接口过于频繁,例如短信服务。
限制访问频率的前提是找到访问的唯一标识。对于已登录用户,可以用用户信息主键、ID、用户名作为唯一标识。对于未登录用户,一般用IP作为唯一标识(容易被代理IP绕过),可再配合前端JS算法辅助识别。
限制方法(以10分钟访问3次的限制为例):
- 维护一个记录访问时间的列表,列表名是用户的唯一标识
- 收到请求后,将当前时间记录进列表
- 删除列表中超过当前时间十分钟的访问记录
- 计算列表长度,超过则触发限流,报错;未超过则允许访问
快速使用
DRF的基础限流类
from rest_framework.throttling import BaseThrottle
class BaseThrottle:
    """
    Base class for throttles.

    `allow_request()` is the core of every throttle: returning `True` means
    the request was not throttled.
    """

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.

        Must be overridden by subclasses.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """
        Return the identifier of the client, which for this base class is
        an IP address.

        If `NUM_PROXIES` is set, the address is picked out of
        HTTP_X_FORWARDED_FOR (when that header is present and the number of
        proxies is > 0). Otherwise all of HTTP_X_FORWARDED_FOR is used if it
        is available, and REMOTE_ADDR if it is not.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            # Count back from the end of the list, clamped to its length.
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        """
        Called after `allow_request()` returned `False`.

        May optionally return the number of seconds to wait before the next
        request; this base implementation returns `None`.
        """
        return None
一个更好用的在基础限流类之上扩展的限流类:SimpleRateThrottle
from rest_framework.throttling import SimpleRateThrottle
class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the Throttle
    class. The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # A `rate` set on the class takes precedence; otherwise it is looked
        # up by scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.

        Raises `ImproperlyConfigured` when no scope is set, or when
        `THROTTLE_RATES` has no entry for the scope.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # Only the first character of the period selects the duration.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
在此基础之上,构建自己的限流类:
from rest_framework.throttling import BaseThrottle, SimpleRateThrottle
from django.core.cache import cache as default_cache
class MyThrottle(SimpleRateThrottle):
    """Throttle allowing five requests per minute per client."""

    scope = "my_throttle"
    # Allowed rate for this scope: five requests per minute.
    THROTTLE_RATES = {"my_throttle": "5/m"}
    # Cache backend configured in settings.py.
    cache = default_cache

    def get_cache_key(self, request, view):
        """
        Return the cache key under which this client's request history is
        stored, formatted as 'throttle_%(scope)s_%(ident)s'.

        Authenticated users are identified by their primary key; everyone
        else is identified by `BaseThrottle.get_ident()` (the IP address).
        """
        user = request.user
        # An anonymous user object is truthy, so truthiness alone does not
        # tell logged-in users apart; `is_authenticated` must be checked or
        # all anonymous clients would share one key with ident `None`.
        if user and user.is_authenticated:
            ident = user.pk
        else:
            ident = self.get_ident(request)

        # Fill in SimpleRateThrottle's cache_format template.
        return self.cache_format % {'scope': self.scope, 'ident': ident}
其中default_cache要在settings.py中设置,这里我们以redis作为我们的缓存:
CACHES = {
    'default': {
        # This backend comes from the django-redis package, which must be
        # installed.
        'BACKEND': 'django_redis.cache.RedisCache',
        'LOCATION': 'redis://127.0.0.1:6379/0',
        'OPTIONS': {
            'CLIENT_CLASS': 'django_redis.client.DefaultClient',
            'PASSWORD': '<PASSWORD>',
        },
    },
}
再将我们创建的限流类应用到视图中:
class LoginView(APIView):
    """View whose requests are throttled by `MyThrottle`."""

    throttle_classes = [MyThrottle]
对象的创建与加载
限流类的实例化:
class APIView(View):
    # Global throttle configuration, used unless a view overrides
    # `throttle_classes`.
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
    ...

    def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.
        """
        self.args = args
        self.kwargs = kwargs
        # Wrap the incoming request.
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            # Authentication, permission and throttle checks all run here.
            self.initial(request, *args, **kwargs)
            ...
        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

    def initial(self, request, *args, **kwargs):
        """
        Run the checks that must pass before the request is handled.
        """
        ...
        # Ensure that the incoming request is permitted
        self.perform_authentication(request)  # authentication
        self.check_permissions(request)  # permissions
        self.check_throttles(request)  # throttling

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        # get_throttles() returns a list of throttle instances.
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            # Report the longest wait among the throttles that rejected.
            duration = max(durations, default=None)
            self.throttled(request, duration)

    def get_throttles(self):
        """
        Instantiate and return every throttle class in `throttle_classes`.
        """
        return [throttle() for throttle in self.throttle_classes]
class SimpleRateThrottle(BaseThrottle):
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # Resolve the rate string once per instance, then split it into the
        # request count and the window length.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.

        Raises `ImproperlyConfigured` when no scope is set, or when
        `THROTTLE_RATES` has no entry for the scope.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # MyThrottle defines THROTTLE_RATES = {"my_throttle": "5/m"}, so
            # this is a plain dict lookup that yields the rate for this
            # throttle's scope. The rates can also be defined globally in
            # settings.py under DEFAULT_THROTTLE_RATES.
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # period[0] takes only the first character of the period, so
        # "10/min" or "33/hour" still find the right key.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)
请求进来,先执行dispatch方法,然后执行initial方法,最后执行check_throttles方法。在check_throttles方法中,通过get_throttles方法获取到由视图类的throttle_classes属性中每一个限流类的实例组成的列表。
初始化限流类的时候,首先执行get_rate方法。这样会尝试获取该限流类的访问频率限制,是从THROTTLE_RATES这个属性中获取。或者也可以在settings.py的REST framework配置中定义:
REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': (
        'rest_framework.authentication.SessionAuthentication',
    ),
    'DEFAULT_PERMISSION_CLASSES': (
        'rest_framework.permissions.IsAuthenticated',
    ),
    # Keys are throttle *scope* names, not class names: get_rate() looks the
    # rate up with THROTTLE_RATES[self.scope].
    'DEFAULT_THROTTLE_RATES': {
        "my_throttle": "5/m",
        "another_throttle": "10/m",
    },
}
get_rate方法获取的访问频率会在__init__方法中交给parse_rate方法解析,并赋值给self.num_requests、self.duration两个属性。num_requests是访问次数,duration是单位时间,以秒计算。
综上所述,获取每个限流类对象的过程,就是每个限流类获取自己的rate属性,即该限流类的时间间隔以及访问次数的过程。
限流类的执行
在APIView视图类的check_throttles方法中,循环获取每个限流类的实例,并执行每个实例的allow_request方法。如果返回True,则继续检查下一个限流类;若返回False,则从该限流类的wait方法中获取需等待的时间,并加入列表throttle_durations中。
class SimpleRateThrottle(BaseThrottle):
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        # Unique identifier of the client making the request.
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # Past request timestamps recorded for this identifier.
        self.history = self.cache.get(self.key, [])
        # Current timestamp; self.timer is time.time.
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration, so what remains falls inside the window.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # The length of the history is the number of requests in the window.
        if len(self.history) >= self.num_requests:
            # Limit reached: the request is throttled.
            return self.throttle_failure()
        # Limit not reached: the request is allowed.
        return self.throttle_success()

    def throttle_success(self):
        """
        Add the current request's timestamp to the history and store the
        history in the cache.
        """
        self.history.insert(0, self.now)
        # The cache entry is written with the throttle duration as its timeout.
        self.cache.set(self.key, self.history, self.duration)
        # True means the request is not throttled.
        return True

    def throttle_failure(self):
        """
        Called when a request has been throttled; returns `False`.
        """
        return False

    def wait(self):
        """
        Return how many seconds to wait before the next request that will
        not be throttled.
        """
        if self.history:
            # Time left until the oldest recorded request leaves the window.
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
如果allow_request返回False,那么wait方法会返回需要等待的时间;这些不同限流类的等待时间都存在throttle_durations列表中,check_throttles方法会选出其中最大的等待时间,并通过throttled方法抛出异常。
class APIView(View):
    def throttled(self, request, wait):
        """
        If request is throttled, determine what kind of exception to raise.
        """
        raise exceptions.Throttled(wait)

    def dispatch(self, request, *args, **kwargs):
        ...
        try:
            self.initial(request, *args, **kwargs)
            ...
        except Exception as exc:
            # An exception raised because of throttling is caught here.
            response = self.handle_exception(exc)
        ...
class Throttled(APIException):
    """
    Exception raised for a throttled request (HTTP 429).

    When `wait` is given it is rounded up to whole seconds, stored on the
    instance, and appended to the detail message.
    """
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_detail = _('Request was throttled.')
    extra_detail_singular = _('Expected available in {wait} second.')
    extra_detail_plural = _('Expected available in {wait} seconds.')
    default_code = 'throttled'

    def __init__(self, wait=None, detail=None, code=None):
        if detail is None:
            detail = force_str(self.default_detail)
        if wait is not None:
            wait = math.ceil(wait)
            # ngettext picks the singular or plural message based on `wait`.
            detail = ' '.join((
                detail,
                force_str(ngettext(self.extra_detail_singular.format(wait=wait),
                                   self.extra_detail_plural.format(wait=wait),
                                   wait))))
        self.wait = wait
        super().__init__(detail, code)
用户登录简单应用
一个未登录用户访问的登录接口,对IP进行限制:
class IPThrottle(SimpleRateThrottle):
    """Throttle keyed on the client's IP address."""

    scope = "ip_throttle"
    # Cache backend configured in settings.py.
    cache = default_cache

    def get_cache_key(self, request, view):
        """
        Return the cache key for this request, using the value from
        `get_ident()` (the IP address) as the identifier.
        """
        ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}
class UserThrottle(SimpleRateThrottle):
    """Throttle keyed on the authenticated user's primary key."""

    scope = "user_throttle"
    cache = default_cache

    def get_cache_key(self, request, view):
        """
        Return the cache key for this request.

        Authenticated users are identified by their primary key. Any other
        request falls back to `get_ident()` (the IP address), so anonymous
        clients neither share a single `None` bucket nor raise when
        `request.user` is `None`.
        """
        user = request.user
        if user and user.is_authenticated:
            ident = user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}
class LoginView(APIView):
    """Login endpoint, throttled per IP address by `IPThrottle`."""

    throttle_classes = [IPThrottle]
class UserVIew(APIView):
    """View throttled per user by `UserThrottle`."""

    throttle_classes = [UserThrottle]
别忘记在settings.py中设置DEFAULT_THROTTLE_RATES:
REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': (
        'rest_framework.authentication.SessionAuthentication',
    ),
    'DEFAULT_PERMISSION_CLASSES': (
        'rest_framework.permissions.IsAuthenticated',
    ),
    # Rates keyed by throttle scope name.
    'DEFAULT_THROTTLE_RATES': {
        "ip_throttle": "20/m",
        "user_throttle": "10/m",
    },
}