DjangoRestFramework 在使用 Mysql 时获得记录总数过慢的问题

589 阅读3分钟

Django Rest Framework(以下简称DRF)在返回响应时, 如果结果是分页的, 则响应大致如下:

{
    count: 17,
    next: "...",
    previous: "...",
    results: []
}

其中count字段是由过滤后的, 未分页前的queryset.count()方法获得的, 也就是

SELECT COUNT(*AS `__count` FROM `table` WHERE ...

每次翻页, 这个count都会实时查询, 如果记录数较多, 每次使用mysqlcount(*)耗时就不容忽视.

为了获得准确的记录总数, 这是必要的, 但如果这个总数的精准度是不太重要的, 甚至记录数是长时间不变的, 那么将其缓存就是不错的解决方案.

另一个问题就是在页数过大时(例如取出第1200页), 响应明显变慢的问题, 实际也就是offset过大时查询变慢的问题, 可以这样解决: 只取出主键列进行offsetlimit, 然后通过主键值取出记录.

以下是我写的一些代码

# paginator.py
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
from collections import OrderedDict
from django.core.cache import cache


class FrontCustomPageSizePagination(PageNumberPagination):
    """
    继承PageNumberPagination是有必要的, DRF前端模板会用到比视图中更多的方法
    """
    
    def set_count(self, queryset, request, view=None):
        """
        使用缓存解决count(*)在数据量较大时耗时过长的问题
        一般而言`queryset`的过滤结果, 应该取决于查询参数和模型
        `set_count`会根据(`page`和`page_size`以外的)查询参数和对应的`model`名拼成一个键, 
        尝试在缓存中获取这个键的值, 如果没有就调用一次`count()`然后将其缓存起来, 过期时间1小时
        """
        query = dict(request.query_params)
        query.pop('page', None)
        query.pop('page_size', None)
        params = ':'.join(sorted(
            [
                f'{key}:{value[0] if isinstance(value, list) and len(value) == 1 else value}' 
                for key, value in query.items()
            ]
        ))
        cached_count_key = str(queryset.model.__name__) + ':' + params
        """
        例如 URL: /api/v1/users?province=江苏省&city=无锡市&page=1&page_size=100
        缓存键会是 User:province:江苏省:city:无锡市
        page和page_size不会影响到键, 因此在切换分页时可以共用缓存
        各个参数会重新排序, 因此URL中查询参数的位置调整并不会使缓存失效
        """
        self._count = cache.get(cached_count_key)
        if self._count == None:
            """
            此处写成`queryset.count()`, 和`queryset.values('pk').count()`
            实际执行的都会是`count(*)`, 因此并无区别.
            """
            self._count = queryset.count()
            cache.set(cached_count_key, self._count, 60 * 60)
    
    def paginate_queryset(self, queryset, request, view=None):
        self.set_count(queryset, request, view)
        """
        前端调整页数和页面大小的查询参数固定为 page 和 page_size
        并且 page_size 最大为 100, 如有需要在此处修改代码
        """
        page = int(request.query_params.get('page', 1))
        page_size = min(int(request.query_params.get('page_size', 100)), 100) 
        start = (page - 1) * page_size
        end = start + page_size
        """
        解决 offset 过大(即页数过大时)查询过慢的问题
        先只查询主键列(速度较快), 进行offset和limit, 然后通过主键值取出记录
        使用 list 是有必要的, 否则会报错
        This version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery'
        此处本有个if, 即希望在offset超过某个阈值后再用这种方式
        但实际应用中, 发现即使offset很小时(不超过数千)时这种方式也不会太慢
        """
        pks = list(queryset.values_list('pk', flat=True)[start:end])
        queryset = queryset.model.objects.filter(pk__in=pks)
        return queryset

    def get_paginated_response(self, data):
        """
        因为前端并没有使用next和previous, 这里直接取消了这两个字段
        """
        return Response(OrderedDict([
            ('count', self._count),
            ('results', data)
        ]))

然后将其作为某个APIViewpagination_class, 或应用至全局即可生效.