drf-版本组件以及源码分析

131 阅读3分钟

drf-版本组件以及源码分析

大致了解

版本组件中一般使用三个类来进行版本控制,分别是:

  • GET参数传递的:QueryParameterVersioning
  • URL路径传递的URLPathVersioning
  • 请求头传递的:AcceptHeaderVersioning

image-20210819155617845

而这三个类都遵循着三个

REST_FRAMEWORK = {
    "DEFAULT_VERSION":"v1",         # 默认版本
    "ALLOWED_VERSIONS":['v1','v2'], # 允许的版本
    'VERSION_PARAM':'version'       # URL中获取值的key
}

局部使用(以URL路径传递为例)

url中写version

urlpatterns = [
    re_path(r'^(?P<version>[v1|v2]+)/order/$', views.OrderView.as_view(), name='order'),
    re_path(r'^(?P<version>\w+)/order/$', views.OrderView.as_view(), name='order'),
]

视图中应用

在CBV中添上versioning_class = URLPathVersioning

from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
​
​
class OrderView(APIView):
    versioning_class = URLPathVersioning
​
    def get(self, request, *args, **kwargs):
        print(request.version)
        print(request.versioning_scheme)
        # 反向生成url,需要填入视图name和request
        print(request.versioning_scheme.reverse('order',request=request))
        return Response({'msg': request.version})
​
    def post(self,request,*args,**kwargs):
        return Response({'msg': request.version})

settings中配置

限制版本数,以及url读取正则的组名

REST_FRAMEWORK = {
    "DEFAULT_VERSION":"v1",         # 默认版本
    "ALLOWED_VERSIONS":['v1','v2'], # 允许的版本
    'VERSION_PARAM':'version'       # URL中获取值的key
}

全局使用(推荐)

url中写version

urlpatterns = [
    re_path(r'^(?P<version>[v1|v2]+)/order/$', views.OrderView.as_view(), name='order'),
    re_path(r'^(?P<version>\w+)/order/$', views.OrderView.as_view(), name='order'),
]

视图中应用

当使用全局配置时,就不需要在自定义的视图类中使用versioning_class = URLPathVersioning这段代码了。

from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
​
​
class OrderView(APIView):
    def get(self, request, *args, **kwargs):
        print(request.version)
        print(request.versioning_scheme)
        print(request.versioning_scheme.reverse('order',request=request))
        return Response({'msg': request.version})
​
    def post(self,request,*args,**kwargs):
        return Response({'msg': request.version})

settings中配置

直接全局里配置,子类不配置,在父类APIview里可以找到versioning_class = api_settings.DEFAULT_VERSIONING_CLASS

REST_FRAMEWORK = {
    "DEFAULT_VERSIONING_CLASS":"rest_framework.versioning.URLPathVersioning",
    "ALLOWED_VERSIONS":['v1','v2'],
    'VERSION_PARAM':'version',
    "DEFAULT_VERSION":"v1"
}

怎么灵巧地设置全局setting-版本

获取REST_FRAMEWORK = {}的key值

点进 APIview的源码找到【DEFAULT_VERSIONING_CLASS】

versioning_class = api_settings.DEFAULT_VERSIONING_CLASS

再点【api_settings】找到【REST_FRAMEWORK】

def reload_api_settings(*args, **kwargs):
    setting = kwargs['setting']
    if setting == 'REST_FRAMEWORK':
        api_settings.reload()
REST_FRAMEWORK = {
    "DEFAULT_VERSIONING_CLASS": ""
}

获取类的路径(value值)

导入想要用的版本类来获得类的路径

from rest_framework.versioning import URLPathVersioning
REST_FRAMEWORK = {
    "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning"
}

填入允许的版本、URL中获取值的key、默认版本

填入运行存在的版本,点进去【URLPathVersioning】找到【BaseVersioning】中的

class BaseVersioning:
    default_version = api_settings.DEFAULT_VERSION
    allowed_versions = api_settings.ALLOWED_VERSIONS
    version_param = api_settings.VERSION_PARAM

源码简略分析

第一部分:

image-20220918212923895

第二部分:

image-20220918213611704

源码逐步分析

源码分析例子类是QueryParameterVersioning,其他同理,只是获取版本值处的代码稍有不同

第一部分:request.versionrequest.versioning_scheme的校验和赋值

①、请求进来是执行CBV的dispatch方法,这次依旧只需要关注initial方法

image-20220913095013273

②、进入initial方法,中间那两行代码就是drf的版本控制源码,目的是将获取到的版本值和版本类对象分别赋值给request.versionrequest.versioning_scheme

image-20220918204104846

③、进入到self.determine_version方法中

  • 如果没有局部设置或者全局设置版本组件的类,那么就会给request.versionrequest.versioning_scheme返回None
  • 若有指定的类,那么将会将scheme.determine_version的返回值赋值给request.version,版本类实例化对象赋值给request.versioning_scheme
def determine_version(self, request, *args, **kwargs):
    if self.versioning_class is None:
    	return (None, None)
    scheme = self.versioning_class()
    return (scheme.determine_version(request, *args, **kwargs), scheme)

④、进入到QueryParameterVersioning版本类中request.versioning_scheme方法

  • 第一行是在通过VERSION_PARAM(版本值的key)来获取GET参数中传递的版本值,如果没有则是 default_versio(默认版本值)
  • 第二行是如果获取到的版本值没有在允许的版本中,则报错。
  • 若是在允许的版本内,则返回这个获取到的version值给request.version
class QueryParameterVersioning(BaseVersioning):
    """
    GET /something/?version=0.1 HTTP/1.1
    Host: example.com
    Accept: application/json
    """
    invalid_version_message = _('Invalid version in query parameter.')

    def determine_version(self, request, *args, **kwargs):
        version = request.query_params.get(self.version_param, self.default_version)
        if not self.is_allowed_version(version):
            raise exceptions.NotFound(self.invalid_version_message)
        return version

⑤、进入第二行的self.is_allowed_version方法,在【BaseVersioning】类中

  • 如果没有指定允许的版本值,那么返回True
  • 若有则,如果版本值不为空且为默认版本,返回True
  • 又或者版本值在允许的版本内,则返回True
def is_allowed_version(self, version):
    if not self.allowed_versions:
        return True
    return ((version is not None and version == self.default_version) or
            (version in self.allowed_versions))

第二部分:反向生成

①进入QueryParameterVersioning类中的reverse方法

def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
    url = super().reverse(
        viewname, args, kwargs, request, format, **extra
    )
    if request.version is not None:
        return replace_query_param(url, self.version_param, request.version)
    return url

②进入父类BaseVersioning

def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
    return _reverse(viewname, args, kwargs, request, format, **extra)

③、进入 _reverse,这里只需要关注这个django_reverse方法,其实这个方法就是django默认的反向生成函数

image-20220918210701749

④、回到①中

  • 如果版本为空,则返回这个反向生成的url
  • 没有的话,则在replace_query_param方法中加入版本重新生成对应的url

image-20220918210807369