drf中认证组件(authentication)源码分析

62 阅读2分钟

1.认证组件概述

在开发API过程中,有些功能需要登录才能访问,有些无需登录。drf中的认证组件主要就是用来实现此功能。 认证组件中,如果是使用了多个认证类,会按照顺序逐一执行其中的authenticate方法。

  • 返回None或无返回值,表示继续执行后续的认证类;
  • 返回 (user, auth) 元组,则不再继续执行后续的类并将值赋值给request.userrequest.auth
  • 抛出异常 AuthenticationFailed(...),认证失败,不再继续向后走。

2.源码分析

2.1 APIView(View)

源码入口是dispatch(self, request, *args, **kwargs)方法;

  • 通过调用initialize_request方法对request进行重新赋值,形成drf的request,将所有的认证类封装进drf里面的request中;
  • 通过调用initial方法,执行认证相关的功能。
class APIView(View):
    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

    @auth.setter
    def auth(self, value):
        self._auth = value
        self._request.auth = value

    @user.setter
    def user(self, value):    	
        self._user = value
        self._request.user = value


    def _not_authenticated(self):

        self._authenticator = None

        if api_settings.UNAUTHENTICATED_USER:
            self.user = api_settings.UNAUTHENTICATED_USER()
        else:
            self.user = None

        if api_settings.UNAUTHENTICATED_TOKEN:
            self.auth = api_settings.UNAUTHENTICATED_TOKEN()
        else:
            self.auth = None

    def _authenticate(self):
 
        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:

            	# 核心代码-->9 认证通过,将user, auth(token)进行赋值
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return

        # 认证不通过,执行_not_authenticated方法
        self._not_authenticated()


    @property
    def user(self):
        if not hasattr(self, '_user'):
            with wrap_attributeerrors():

            	# 核心代码--> 8	通过_authenticate方法执行认证;
                self._authenticate()
        return self._user

    def perform_authentication(self, request):
        # 核心代码--> 7 查找request下的user方法,执行认证;
        request.user


    def initial(self, request, *args, **kwargs):
        
        self.format_kwarg = self.get_format_suffix(**kwargs)

        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # 核心代码--> 6 通过perform_authentication方法执行认证;
        self.perform_authentication(request)

        self.check_permissions(request)
        self.check_throttles(request)

    def get_authenticators(self):
        # 核心代码--> 4 通过自生成列表方法生成各个认证对象()[认证对象1, 认证对象2,  认证对象3……]
        # self.authentication_classes 表示如果当前视图中没有定义认证类,则通过APIView(View)查找:api_settings.DEFAULT_AUTHENTICATION_CLASSES
        
        return [auth() for auth in self.authentication_classes]


    def initialize_request(self, request, *args, **kwargs):
        parser_context = self.get_parser_context(request)  

        # 核心代码   -->2 使用Request,将request,authenticators等封装进drf中request;
        return Request(
            request,
            parsers=self.get_parsers(),

            # 核心代码--> 3 将所有的认证对象封装进dispatch里面的request中;
            authenticators=self.get_authenticators(),

            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )


    def dispatch(self, request, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

        # 核心代码--> 1 调用initialize_request方法对request进行重新赋值,形成drf的request;
        request = self.initialize_request(request, *args, **kwargs)        
        self.request = request

        self.headers = self.default_response_headers  # deprecate?
        try:

            # 核心代码--> 5 执行认证
            self.initial(request, *args, **kwargs)

            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        # 将response进行封装
        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response


class UserView(APIView):
    authentication_classes = [认证组件1, 认证组件2,……]
    def get(self, request, *args, **kwargs):
        return Response('user')

2.2 Request

对django中的request进行进一步的封装,增加drf相关内容。

class Request:

    def __init__(self, request, parsers=None, authenticators=None,
                 negotiator=None, parser_context=None):
        assert isinstance(request, HttpRequest), (

            .format(request.__class__.__module__, request.__class__.__name__)
        )
        self._request = request
        self.parsers = parsers or ()

        # 核心代码--> 3 初始化赋值authenticators
        self.authenticators = authenticators or ()

        self.negotiator = negotiator or self._default_negotiator()
        self.parser_context = parser_context
        self._data = Empty
        self._files = Empty
        self._full_data = Empty
        self._content_type = Empty
        self._stream = Empty

        if self.parser_context is None:
            self.parser_context = {}
        self.parser_context['request'] = self
        self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET

        force_user = getattr(request, '_force_auth_user', None)
        force_token = getattr(request, '_force_auth_token', None)
        if force_user is not None or force_token is not None:
            forced_auth = ForcedAuthentication(force_user, force_token)
            self.authenticators = (forced_auth,)

2.3 ForcedAuthentication

赋值操作,返回结果为元组。

class ForcedAuthentication:

    def __init__(self, force_user, force_token):
        self.force_user = force_user
        self.force_token = force_token

    def authenticate(self, request):
        return (self.force_user, self.force_token)