1.限流组件概述
限流,即限制用户的访问频率,例如:用户 1 分钟内最多访问 100 次,或者短信验证码每天最多发送 50 次,以防止盗刷。限流时需要确定“用户”的唯一标识:
- 已登录用户:以用户信息的主键(ID)或用户名作为唯一标识
- 未登录用户:以 IP 作为唯一标识
2.限流组件应用
2.1 编写类
from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache # 连接redis
class IpThrottle(SimpleRateThrottle):
    """Rate-limit requests per client IP address.

    The allowed rate is the one configured for the ``"ip"`` scope in
    ``REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]``; request histories are
    stored in Django's default cache.
    """

    scope = "ip"
    cache = default_cache

    def get_cache_key(self, request, view):
        # Client identifier (IP) as computed by the inherited get_ident().
        client_ip = self.get_ident(request)
        fields = {'scope': self.scope, 'ident': client_ip}
        return self.cache_format % fields
class UserThrottle(SimpleRateThrottle):
    """Rate-limit requests per user, falling back to the client IP.

    Authenticated requests are keyed on the user's primary key; anonymous
    requests are keyed on the client IP. The allowed rate is the one
    configured for the ``"user"`` scope in
    ``REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]``.
    """

    scope = "user"
    cache = default_cache

    def get_cache_key(self, request, view):
        user = request.user
        if user is not None and user.is_authenticated:
            ident = user.pk  # user ID
        else:
            # An anonymous user's pk is None; keying on it would put every
            # anonymous client into one shared bucket ("throttle_user_None"),
            # so identify them by IP instead.
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}
2.2 安装django-redis并在settings中进行配置
2.2.1 redis相关配置
# Django cache settings: the "default" cache is Redis, accessed via django-redis.
# Throttle classes that use django.core.cache.cache store their request
# histories here.
# NOTE: "密码" ("password") is a placeholder - replace it with the real Redis
# password, ideally loaded from the environment rather than hard-coded.
CACHES = {
    "default": dict(
        BACKEND="django_redis.cache.RedisCache",
        LOCATION="redis://127.0.0.1:6379",
        OPTIONS=dict(
            CLIENT_CLASS="django_redis.client.DefaultClient",
            PASSWORD="密码",
        ),
    ),
}
2.2.2 限流频率相关配置
# DRF throttle rates, keyed by each throttle class's `scope`.
# Format is "<count>/<period>"; only the first letter of the period is used
# (s = second, m = minute, h = hour, d = day).
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_RATES=dict(
        ip="10/m",  # scope "ip": 10 requests per minute
        user="5/m",  # scope "user": 5 requests per minute
    ),
)
2.3 启动redis服务
启动 Redis 服务(例如执行 `redis-server`),并确保其地址和端口与 CACHES 中 LOCATION 配置的 redis://127.0.0.1:6379 一致。
2.4 局部应用(在单个视图中通过 throttle_classes 指定限流类)
from rest_framework.views import APIView
from rest_framework.response import Response
from ext.throttle import MyThrottle
class LoginView(APIView):
    """Log a user in and issue a random token.

    Expects ``username`` and ``password`` in the POST body. Responds with
    ``{"status": True, "data": <token>}`` on success, or
    ``{"status": False, "msg": ...}`` when no matching user is found.
    """

    # NOTE(review): the throttle classes written in section 2.1 are IpThrottle
    # and UserThrottle; confirm that ext.throttle really provides MyThrottle.
    # Login requests normally come from unauthenticated clients, so a per-IP
    # throttle is presumably what is wanted here.
    throttle_classes = [MyThrottle, ]

    def post(self, request):
        import uuid  # stdlib; used below to generate the login token

        # 1. Read the submitted username and password.
        username = request.data.get("username")
        password = request.data.get("password")

        # 2. Look the user up in the database.
        # NOTE(review): `models` is not imported in this snippet; it is
        # presumably the app's models module - add that import.
        # NOTE(security): the raw password is matched against the stored
        # value, which means passwords are stored in plaintext. Store hashes
        # instead (django.contrib.auth.hashers.make_password / check_password).
        user_object = models.UserInfo.objects.filter(
            username=username, password=password
        ).first()
        if not user_object:
            return Response({"status": False, 'msg': "用户名或密码错误"})

        # 3. Credentials matched: issue a fresh random token and persist it.
        token = str(uuid.uuid4())
        user_object.token = token
        user_object.save()
        return Response({"status": True, 'data': token})
3.限流组件源码分析
# Imports must come before the class statement: the base class name
# SimpleRateThrottle has to be resolvable when the class is created.
from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache


class IpThrottle(SimpleRateThrottle):  # custom throttle class
    """Rate-limit requests per client IP address (rate from the "ip" scope)."""

    scope = "ip"
    cache = default_cache

    def get_cache_key(self, request, view):
        # Client identifier (IP) as computed by the inherited get_ident().
        ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}
# Imports must come before the class statement: the base class name
# SimpleRateThrottle has to be resolvable when the class is created.
from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache


class UserThrottle(SimpleRateThrottle):  # custom throttle class
    """Rate-limit per user (rate from the "user" scope), falling back to IP.

    Authenticated requests are keyed on the user's primary key; anonymous
    requests are keyed on the client IP.
    """

    scope = "user"
    cache = default_cache

    def get_cache_key(self, request, view):
        user = request.user
        if user is not None and user.is_authenticated:
            ident = user.pk  # user ID
        else:
            # An anonymous user's pk is None; keying on it would make all
            # anonymous clients share the single "throttle_user_None" bucket.
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}
# Excerpt of Django REST framework's SimpleRateThrottle, reproduced for source
# analysis (indentation lost in this note). It is a sliding-window limiter:
# for each cache key it keeps a list of request timestamps, newest first, and
# allows a request only while fewer than `num_requests` timestamps fall inside
# the last `duration` seconds.
class SimpleRateThrottle(BaseThrottle):
# Cache backend holding the per-key timestamp lists.
cache = default_cache
# Clock used for timestamps (time.time: seconds since the epoch, as a float).
timer = time.time
# Template for cache keys; subclasses fill in `scope` and `ident`.
cache_format = 'throttle_%(scope)s_%(ident)s'
# Key used to look the rate up in THROTTLE_RATES; subclasses must set it (or `rate`).
scope = None
# Mapping of scope -> rate string, taken from the DEFAULT_THROTTLE_RATES API setting.
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
# A `rate` attribute set on the class wins; otherwise look it up by `scope`.
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
# Split the rate into (allowed request count, window length in seconds).
self.num_requests, self.duration = self.parse_rate(self.rate)
def get_cache_key(self, request, view):
# Subclasses return a key identifying the client, or None to skip
# throttling for this request (allow_request then returns True).
raise NotImplementedError('.get_cache_key() must be overridden')
def get_rate(self):
# Return the rate string configured for self.scope.
# Raises ImproperlyConfigured if scope is unset or has no configured rate.
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
# Core step 5: fetch the rate string for this scope, e.g. "10/m".
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
# Core step 6: parse a rate string such as "10/m" into
# (num_requests, duration), where num_requests is the allowed count and
# duration is the window length in seconds. A None rate yields (None, None).
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] # only period[0] is used, so "min", "hour", "day" also work; an unknown first letter (e.g. "5m") raises KeyError
return (num_requests, duration)
def allow_request(self, request, view):
# Return True if the request may proceed, False if it is throttled.
# No rate configured (e.g. the scope's rate is None): never throttle.
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
# A None key means "do not throttle this request".
if self.key is None:
return True
# Earlier request timestamps for this key, newest first ([] if none cached).
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Core step 8: the throttling itself. Drop timestamps that have left the
# window; the oldest entries sit at the end of the list.
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
# The window already holds num_requests requests: reject.
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
# Record this request at the front (newest first) and store the list; the
# third argument is the cache timeout, so an idle key expires after one window.
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
# Rejected requests are not added to the history.
return False
def wait(self):
# Suggested number of seconds to wait before retrying (used by
# APIView.check_throttles after a rejection).
if self.history:
# Time until the oldest recorded request leaves the window.
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
# Spread the remaining time over the requests that will become available.
available_requests = self.num_requests - len(self.history) + 1
# Can be <= 0 only if the history holds more than num_requests entries.
if available_requests <= 0:
return None
return remaining_duration / float(available_requests)
# Excerpt of Django REST framework's APIView showing only the throttling path
# (many methods omitted; indentation lost in this note). The "Core step N"
# comments number the call sequence from dispatch() down to allow_request().
class APIView(View):
# Throttle classes used by every view unless the view overrides this
# attribute; the default comes from the DEFAULT_THROTTLE_CLASSES API setting.
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
def get_throttles(self):
# Core step 4: instantiate each throttle class. A fresh instance is built on
# every call, so state set in allow_request (history, now) is per request.
return [throttle() for throttle in self.throttle_classes]
def check_throttles(self, request):
# Run every throttle; if any rejects the request, call self.throttled with
# the longest suggested wait (or None if no throttle suggested one).
throttle_durations = []
# Core step 3: obtain the throttle instances via get_throttles().
for throttle in self.get_throttles():
# Core step 7: run allow_request(request, self); True means the request
# passes this throttle. There is no early break, so every throttle is
# evaluated and the ones that pass still record the request.
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
if throttle_durations:
# wait() may return None; ignore those when picking the delay.
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
# NOTE(review): throttled() is not shown here; in DRF it presumably raises
# a Throttled exception carrying `duration` - confirm against upstream.
self.throttled(request, duration)
def initial(self, request, *args, **kwargs):
# Pre-handler work: format suffix, content negotiation, versioning, then
# authentication, permission checks and finally throttling.
self.format_kwarg = self.get_format_suffix(**kwargs)
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
self.perform_authentication(request)
self.check_permissions(request)
# Core step 2: throttling runs after authentication and permissions.
self.check_throttles(request)
def dispatch(self, request, *args, **kwargs):
# Entry point for every request: wrap it, run initial() checks, call the
# handler for the HTTP method, and finalize the response.
self.args = args
self.kwargs = kwargs
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers # deprecate?
try:
# Core step 1: initial() performs auth / permission / throttle checks
# before the handler runs.
self.initial(request, *args, **kwargs)
# Unknown or unimplemented methods go to http_method_not_allowed.
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
response = handler(request, *args, **kwargs)
# Exceptions from initial() (including throttling) or the handler are
# converted into a response by handle_exception.
except Exception as exc:
response = self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response