DRF 源码解析 认证组件(二)

86 阅读5分钟

认证组件应用

认证组件实现用户授权的校验。

用户登录成功后,后台会返回一个凭证(token),保持登录状态。之后登录用户发送请求会携带token。

自定义一个认证类:

from rest_framework.authentication import BaseAuthentication

class MyAuthentication(BaseAuthentication):
    def authenticate(self, request):
        ...
    def authenticate_header(self, request):
        ...

自定义认证类继承自BaseAuthentication

from rest_framework.exceptions import AuthenticationFailed

class BaseAuthentication:
    # 所有认证类都必须在BaseAuthentication基础之上扩展

    def authenticate(self, request):
        # 认证类的核心方法
        # 去做用户认证:1.读取token 2.校验token的合法性
        # 认证成功返回一个二元组(user, token),并赋值给request.user和request.auth
        # 认证失败可以抛出异常或者返回`None`(多个认证类)
        # 抛出异常 raise AuthenticationFailed("认证失败")
        ...

    def authenticate_header(self, request):
        # 对于`401 Unauthenticated`的响应,返回一个字符串作为标头中`WWW-Authenticate`的值
        # 对于`403 Permission Denied`响应,返回`None`
        ...

在视图类中添加认证类列表

class OrderView(APIView):
    authentication_classes = [MyAuthentication, ]
    # 为视图类OrderView添加认证类列表
    # 现在访问OrderView视图,需要经过MyAuthentication认证类
    def get(self, request):
        return Response("访问成功")

也可以在settings.py中全局配置认证类列表

REST_FRAMEWORK = {
    ...
    "DEFAULT_AUTHENTICATION_CLASSES": [app01.auths.MyAuthentication, ],
    ...
}

访问某个视图类时,DRF优先去找全局的验证类列表,再去找这个视图类内部设置的认证类列表;如果视图类中没有单独设置认证类列表,则使用全局设置的认证类列表,否则使用视图类内部的认证类列表。

面向对象知识:继承

class Base(object):
    a = "attr"
    
    def f1(self):
        self.f2()
        print(self.a)
        ...
        
    def f2(self):
        ...
        
class Foo(Base):

    def f2(self):
        ...
       
obj = Foo()
obj.f1() 

执行的是Base类的f1()方法,在这个方法内的self.f2()方法调用的是Foo类中f2()方法,方法内的self.a是Base类的a属性

认证组件源码流程

class APIView(View):  
    ...
    # 在settings.py中全局配置的认证类列表
    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
    ...
    def dispatch(self, request, *args, **kwargs):
        ...
        # 第一步:请求的封装(Django的request对象 + authenticators认证组件)
        # 封装的请求有认证类列表
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?
        
        try:
            # 第二步:进行验证
            self.initial(request, *args, **kwargs)
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
    
    
    def initialize_request(self, request, *args, **kwargs):
        
        ...
        return Request(
            request,
            authenticators=self.get_authenticators(),
            ...
        )
    
    
    def get_authenticators(self):
        # 返回此视图可以使用的认证类的实例列表。
        # 先去视图类找authentication_classes属性
        # 视图类没有,去APIView类中找authentication_classes
        return [auth() for auth in self.authentication_classes]
    
    
    def initial(self, request, *args, **kwargs):
        ...
        # 这里的request是封装后的request
        self.perform_authentication(request)
        ...
      
    def perform_authentication(self, request):
        # 对传入请求执行身份验证
        # request.user是Request类的一个方法
        request.user
    
    
class Request:
    
    def __init__(self, request, authenticators=None, ...):
        self._request = request
        self.authenticators = authenticators or ()
        ...
    
    @property
    def user(self):
        # 返回与当前请求关联的用户,该用户已通过提供给请求的身份验证类进行身份验证。
        if not hasattr(self, '_user'):
            with wrap_attributeerrors():
                self._authenticate()
        return self._user

    @user.setter
    def user(self, value):
        # 设置当前请求的用户。
        # 这对于保持与 django.contrib.auth 的兼容性是必要的,
        # 在 django.contrib.auth 中,用户属性是在Login和Logout函数中设置的。
        # 我们还设置了 Django 的底层 HttpRequest 实例上的用户,以确保它可用于堆栈中的任何中间件。
        self._user = value
        self._request.user = value
        
    def _authenticate(self):
        # 依次用认证类的实例列表中的认证类实例,去认证请求
        # self即是request对象
        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
                # 认证类的authenticate()方法会返回一个二元组(user, token)
            except exceptions.APIException:
                # 如果有异常,执行_not_authenticated()方法
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:
                self._authenticator = authenticator
                # self._authenticator是成功认证的认证类实例
                self.user, self.auth = user_auth_tuple
                # 将(user, token)赋值给(request.user, request.auth)
                return
        # 如果认证类实例列表中的全部认证类实例都认证失败(全部返回`None`),执行_not_authenticated()方法
        self._not_authenticated()
        
  
    def _not_authenticated(self):
        # 为一个未认证的请求设置authenticator, user和authtoken
        # 默认(self._authenticator, self._user, self._auth) = (None, AnonymousUser, NOne)
        self._authenticator = None

        if api_settings.UNAUTHENTICATED_USER:
            # 读取匿名用户设置
            self.user = api_settings.UNAUTHENTICATED_USER()
        else:
            self.user = None

        if api_settings.UNAUTHENTICATED_TOKEN:
            self.auth = api_settings.UNAUTHENTICATED_TOKEN()
        else:
            self.auth = None

可以设置单独的认证类,作为认证类列表的最后一项,用于抛出异常,不访问视图

from rest_framework.exceptions import AuthenticationFailed

class NotAuthentication(BaseAuthentication):

    def authenticate(self, request):
        raise AuthenticationFailed({"error_code": 20000, "error_reason": "认证失败"})
        
    def authenticate_header(self, request):
        ...

认证失败的状态码

class Request:
    ...
    def _authenticate(self):
        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                self._not_authenticated()
                # 认证失败的异常向上一级(request的user()方法)抛出
                raise
            ...
    
    @property
    def user(self):
        if not hasattr(self, '_user'):
            with wrap_attributeerrors():
                self._authenticate()
                # 没有捕获异常,继续往上抛(APIView类的perform_authentication()方法)
        return self._user
    
    
class APIVIew(View):
    ...
    def perform_authentication(self, request):
        # 也没有捕获异常,继续向上抛(initial()方法)
        request.user
       
    def initial(self, request, *args, **kwargs):
        ...
        # 也没有捕获异常,继续向上抛(dispatch()方法)
        self.perform_authentication(request)
        self.check_permissions(request)
        self.check_throttles(request)
        
    def dispatch(self, request, *args, **kwargs):
        ...
        try:
            self.initial(request, *args, **kwargs)
            ...
        except Exception as exc:
            # 这里有异常处理,执行handle_exception()方法
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
        
        
    def handle_exception(self, exc):
        # 如果异常是认证失败(NotAuhenticated和AuthenticationFailed两种错误)
        # 执行get_authenticate_header()方法
        if isinstance(exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)):
        
            auth_header = self.get_authenticate_header(self.request)
            
         if auth_header:
                exc.auth_header = auth_header
            else:
                # 不存在auth_header的话,默认状态码是HTTP 403
                exc.status_code = status.HTTP_403_FORBIDDEN
        ...
            
     def get_authenticate_header(self, request):
        # 获取认证类实例的列表
        authenticators = self.get_authenticators()
        if authenticators:
            # 执行第一个认证类实力的authenticate_header()方法
            return authenticators[0].authenticate_header(request)

备注

request.data是请求体的全部数据。

class Request:
    ...
    @property  
    def data(self):  
        if not _hasattr(self, '_full_data'):  
            self._load_data_and_files()  
        return self._full_data
    
    def _load_data_and_files(self):  
        # 将请求内容解析为"self.data"
        if not _hasattr(self, '_data'):  
            self._data, self._files = self._parse()  
            if self._files:  
                self._full_data = self._data.copy()  
                self._full_data.update(self._files)  
            else:  
                self._full_data = self._data  

            # if a form media type, copy data & files refs to the underlying  
            # http request so that closable objects are handled appropriately.  
            if is_form_media_type(self.content_type):  
                self._request._post = self.POST  
                self._request._files = self.FILES         
    ...

request.query_params是url中携带的数据。

class Request:
    ...
    @property  
    def query_params(self):  
    # 语义上更正确的`request.GET`
    return self._request.GET
    ...

token可以用uuid库来生成

import uuid
token = str(uuid.uuid5())