DRF 源码解析 版本(五)

133 阅读3分钟

概述

考虑API的更新迭代,会有版本的不同。可能会出现多个版本同时存在的情况。

在GET请求的参数中添加版本参数

class HomeView(APIView):  
    # /home/?version=1  
    # 将version参数赋值给request.version  
    versioning_class = URLPathVersioning

class BaseVersioning:
    default_version = api_settings.DEFAULT_VERSION
    # 在配置中设置有哪些可用的版本号
    allowed_versions = api_settings.ALLOWED_VERSIONS
    # 实际上,version_params参数定义在配置文件settings.py中
    version_param = api_settings.VERSION_PARAM

    def determine_version(self, request, *args, **kwargs):
        msg = '{cls}.determine_version() must be implemented.'
        raise NotImplementedError(msg.format(
            cls=self.__class__.__name__
        ))

    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
        return _reverse(viewname, args, kwargs, request, format, **extra)

    def is_allowed_version(self, version):
        if not self.allowed_versions:
            return True
        return ((version is not None and version == self.default_version) or
                (version in self.allowed_versions))

class QueryParameterVersioning(BaseVersioning):
    # GET /something/?version=0.1 HTTP/1.1
    # Host: example.com
    # Accept: application/json
    invalid_version_message = _('Invalid version in query parameter.')

    def determine_version(self, request, *args, **kwargs):
        # 本质上是从request.query_params,即GET请求参数中读取self.version_param键对应的信息
        # self.version_param 在其父类BaseVersioning定义
        # 如果version_param不存在,那么就去调用默认的default_version
        # 也就是REST_FRAMEWORK设置中的DEFAULT_VERSION参数
        version = request.query_params.get(self.version_param, self.default_version)
        # 取到版本信息后,判断是不是可支持的版本
        # 如果不是可行的版本信息,抛出一个异常,会在视图类的initial方法中被捕获
        if not self.is_allowed_version(version):
            raise exceptions.NotFound(self.invalid_version_message)
        return version

    # 反向生成url,本质上还是调用django的reverse
    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
        url = super().reverse(
            viewname, args, kwargs, request, format, **extra
        )
        if request.version is not None:
            return replace_query_param(url, self.version_param, request.version)
        return url

# /rest_framework/reverse.py
def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
    if format is not None:
        kwargs = kwargs or {}
        kwargs['format'] = format
    # 本质上调用的是django的reverse
    url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
    if request:
        return request.build_absolute_uri(url)
    return url

具体的版本传参名称可以在REST_FRAMEWORK设置中定义,默认是version

REST_FRAMEWORK = {  
    'VERSION_PARAM': 'v',   # 设定的版本传参名称
    'DEFAULT_VERSION': 'version',  # 默认的版本传参名称
    'ALLOWED_VERSIONS': ['1.0', '2.0'],   # VERSION_PARAM 传参中的可行版本
}

这样,传参中的键v的值就是版本了,仍然赋值给request.version

请求在视图类的dispatch方法中被封装,并执行initial方法,其中的determine_version方法中将版本类实例化后执行该版本类实例的determine_version

QueryParameterVersioning视图类为例,它的determine_version方法会从request.query_params[self.version_param]取到具体的版本值,并调用self.is_allowed_version方法判断版本值是否合法。如果不合法,就抛出一个异常,被视图类的dispatch方法捕获。如果合法,就返回具体的版本值。

视图类的determine_version方法返回一个元组:(版本值,版本类的实例)。在视图类的initial方法中,版本值和版本类的实例被分别赋值给request.version, request.versioning_scheme

class APIView(View):
    versioning_class = api_settings.DEFAULT_VERSIONING_CLASS

    def dispatch(self, request, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        # 封装请求
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            # 在initial方法中解决版号问题
            self.initial(request, *args, **kwargs)
            ...
        
        # 捕获异常
        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
        
    def initial(self, request, *args, **kwargs):
        self.format_kwarg = self.get_format_suffix(**kwargs)
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # 将版本信息赋值给request.version
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme
        ...
       
    # 返回一个元组: (版本,版本类的实例)
    def determine_version(self, request, *args, **kwargs):
        # 如果未定义版本类,直接返回版本为空
        if self.versioning_class is None:
            return (None, None)
        # 实例化视图类定义的版本类,可能是默认设置,也可能是继承的类中的新设置
        scheme = self.versioning_class()
        # 执行版本类的determine_version方法
        return (scheme.determine_version(request, *args, **kwargs), scheme)

通过路由(url)匹配传递版本

通过请求头传递版本

class AcceptHeaderVersioning(BaseVersioning):
    # GET /something/ HTTP/1.1
    # Host: example.com
    # Accept: application/json; version=1.0
    invalid_version_message = _('Invalid version in "Accept" header.')

    def determine_version(self, request, *args, **kwargs):
        media_type = _MediaType(request.accepted_media_type)
        version = media_type.params.get(self.version_param, self.default_version)
        version = unicode_http_header(version)
        if not self.is_allowed_version(version):
            raise exceptions.NotAcceptable(self.invalid_version_message)
        return version

将版本信息写入到请求头的Accept: application/json;v=1.0version参数中(因为这里我们将VERSION_PARAM的名称设置为v)。

设置全局的默认版本类:

REST_FRAMEWORK = {  
    'VERSION_PARAM': 'v',  
    'DEFAULT_VERSION': 'version',  
    'ALLOWED_VERSIONS': ['v1', 'v2'],  
    'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.QueryParameterVersioning'  
}