为什么运用异步
- 由于Djnago在4.0版本之后是支持异步且在4.1里的ORM以至不需求运用sync_to_async来装饰,rest框架是不支持的协程
- 异步根本是web效劳开发的趋向,像web框架有fastapi,sanic这些异步框架,Django作为老牌框架像异步迁移是必然的。
为什么要复写rest_framework局部函数
- 由于Django是协程,但是rest不是,会招致调用中途梗,python的协程代码一旦有同步代码则全程同步
代码
generics.py
import asyncio
import inspect
from asgiref.sync import sync_to_async
from rest_framework import exceptions, status
from rest_framework.generics import GenericAPIView, get_object_or_404
from rest_framework.request import Request
from rest_framework.views import APIView
from tortoise import Model
from tortoise.queryset import QuerySet
from async_drf import mixins
class AsyncAPIView(APIView):
"""
参考:
https://github.com/encode/django-rest-framework/issues/7260
Provides async view compatible support for DRF Views and ViewSets.
This must be the first inherited class.
"""
@classmethod
def as_view(cls, *args, **initkwargs):
"""Make Django process the view as an async view."""
view = super().as_view(*args, **initkwargs)
async def async_view(*args, **kwargs):
return await view(*args, **kwargs)
async_view.cls = cls
async_view.initkwargs = initkwargs
async_view.csrf_exempt = True
return async_view
async def get_exception_handler_context(self):
return await sync_to_async(super().get_exception_handler_context)()
async def get_exception_handler(self):
return super().get_exception_handler()
async def handle_exception(self, exc):
"""
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
if isinstance(exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)):
auth_header = await sync_to_async(self.get_authenticate_header)(self.request)
if auth_header:
exc.auth_header = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
exception_handler = await self.get_exception_handler()
context = await self.get_exception_handler_context()
if not asyncio.iscoroutinefunction(exception_handler):
response = await sync_to_async(exception_handler)(exc, context)
else:
response = await exception_handler(exc, context)
if response is None:
await sync_to_async(self.raise_uncaught_exception)(exc)
response.exception = True
return response
async def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""
parser_context = self.get_parser_context(request)
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
parser_context=parser_context,
)
async def dispatch(self, request, *args, **kwargs):
"""Add async support."""
self.args = args
self.kwargs = kwargs
request = await self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers
try:
await sync_to_async(self.initial)(request, *args, **kwargs)
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
if not asyncio.iscoroutinefunction(handler):
handler = sync_to_async(handler)
response = await handler(request, *args, **kwargs)
except Exception as exc:
response = await self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response
class AsyncGenericAPIView(AsyncAPIView, GenericAPIView):
async def paginate_queryset(self, queryset):
"""
Return a single page of results, or `None` if pagination is disabled.
"""
if self.paginator is None:
return None
return await self.paginator.paginate_queryset(queryset, self.request, view=self)
async def get_queryset(self) -> QuerySet:
"""
Get the list of items for this view.
This must be an iterable, and may be a queryset.
Defaults to using `self.queryset`.
This method should always be used rather than accessing `self.queryset`
directly, as `self.queryset` gets evaluated only once, and those results
are cached for all subsequent requests.
You may want to override this if you need to provide different
querysets depending on the incoming request.
(Eg. return a list of items that is specific to the user)
"""
assert self.queryset is not None, (
"'%s' should either include a `queryset` attribute, "
"or override the `get_queryset()` method." % self.__class__.__name__
)
queryset = self.queryset
if isinstance(queryset, QuerySet):
queryset = queryset.all()
return queryset
async def get_object(self) -> Model:
"""
Returns the object the view is displaying.
You may want to override this if you need to provide non-standard
queryset lookups. Eg if objects are referenced using multiple
keyword arguments in the url conf.
"""
queryset = await self.filter_queryset(await self.get_queryset())
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
assert lookup_url_kwarg in self.kwargs, (
"Expected view %s to be called with a URL keyword argument "
'named "%s". Fix your URL conf, or set the `.lookup_field` '
"attribute on the view correctly." % (self.__class__.__name__, lookup_url_kwarg)
)
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = await sync_to_async(get_object_or_404)(queryset, **filter_kwargs)
obj = await obj
await sync_to_async(self.check_object_permissions)(self.request, obj)
return obj
async def get_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
if inspect.iscoroutinefunction(self.get_serializer_class):
serializer_class = await self.get_serializer_class()
else:
serializer_class = await sync_to_async(self.get_serializer_class)()
kwargs.setdefault("context", await sync_to_async(self.get_serializer_context)())
return await sync_to_async(serializer_class)(*args, **kwargs)
async def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, "
"or override the `get_serializer_class()` method." % self.__class__.__name__
)
return self.serializer_class
async def get_paginated_response(self, data):
"""
Return a paginated style `Response` object for the given output data.
"""
assert self.paginator is not None
return await self.paginator.get_paginated_response(data)
async def filter_queryset(self, queryset):
"""
Given a queryset, filter it with whichever filter backend is in use.
You are unlikely to want to override this method, although you may need
to call it either from a list view, or from a custom `get_object`
method if you want to apply the configured filtering backend to the
default queryset.
"""
for backend in list(self.filter_backends):
queryset = await backend().filter_queryset(self.request, queryset, self)
if inspect.iscoroutine(queryset):
return await queryset
return queryset
class AsyncCreateAPIView(mixins.AsyncCreateModelMixin, AsyncGenericAPIView):
"""
Concrete view for creating a model instance.
"""
async def post(self, request, *args, **kwargs):
return await self.create(request, *args, **kwargs)
class AsyncListAPIView(mixins.AsyncListModelMixin, AsyncGenericAPIView):
"""
Concrete view for listing a queryset.
"""
async def get(self, request, *args, **kwargs):
return await self.list(request, *args, **kwargs)
class AsyncRetrieveAPIView(mixins.AsyncRetrieveModelMixin, AsyncGenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
async def get(self, request, *args, **kwargs):
return await self.retrieve(request, *args, **kwargs)
class AsyncDestroyAPIView(mixins.AsyncDestroyModelMixin, AsyncGenericAPIView):
"""
Concrete view for deleting a model instance.
"""
async def delete(self, request, *args, **kwargs):
return await self.destroy(request, *args, **kwargs)
class AsyncUpdateAPIView(mixins.AsyncUpdateModelMixin, AsyncGenericAPIView):
"""
Concrete view for updating a model instance.
"""
async def put(self, request, *args, **kwargs):
return await self.update(request, *args, **kwargs)
async def patch(self, request, *args, **kwargs):
return await self.partial_update(request, *args, **kwargs)
class AsyncListCreateAPIView(mixins.AsyncListModelMixin, mixins.AsyncCreateModelMixin, AsyncGenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
async def get(self, request, *args, **kwargs):
return await self.list(request, *args, **kwargs)
async def post(self, request, *args, **kwargs):
return await self.create(request, *args, **kwargs)
class AsyncRetrieveUpdateAPIView(mixins.AsyncRetrieveModelMixin, mixins.AsyncUpdateModelMixin, AsyncGenericAPIView):
"""
Concrete view for retrieving, updating a model instance.
"""
async def get(self, request, *args, **kwargs):
return await self.retrieve(request, *args, **kwargs)
async def put(self, request, *args, **kwargs):
return await self.update(request, *args, **kwargs)
async def patch(self, request, *args, **kwargs):
return await self.partial_update(request, *args, **kwargs)
class AsyncRetrieveDestroyAPIView(mixins.AsyncRetrieveModelMixin, mixins.AsyncDestroyModelMixin, AsyncGenericAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
async def get(self, request, *args, **kwargs):
return await self.retrieve(request, *args, **kwargs)
async def delete(self, request, *args, **kwargs):
return await self.destroy(request, *args, **kwargs)
class AsyncRetrieveUpdateDestroyAPIView(
mixins.AsyncRetrieveModelMixin, mixins.AsyncUpdateModelMixin, mixins.AsyncDestroyModelMixin, AsyncGenericAPIView
):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
async def get(self, request, *args, **kwargs):
return await self.retrieve(request, *args, **kwargs)
async def put(self, request, *args, **kwargs):
return await self.update(request, *args, **kwargs)
async def patch(self, request, *args, **kwargs):
return await self.partial_update(request, *args, **kwargs)
async def delete(self, request, *args, **kwargs):
return await self.destroy(request, *args, **kwargs)
mxins.py
from typing import Dict
from asgiref.sync import sync_to_async
from rest_framework import status
from rest_framework.mixins import (
CreateModelMixin,
DestroyModelMixin,
ListModelMixin,
RetrieveModelMixin,
UpdateModelMixin,
)
from rest_framework.response import Response
from rest_framework.settings import api_settings
class AsyncCreateModelMixin(CreateModelMixin):
"""Make `create()` and `perform_create()` overridable.
Without inheriting this class, the event loop can't be used in these two methods when override them.
This must be inherited before `CreateModelMixin`.
class MyViewSet(AsyncMixin, GenericViewSet, AsyncCreateModelMixin, CreateModelMixin):
pass
"""
async def create(self, request, *args, **kwargs):
serializer = await self.get_serializer(data=request.data)
await sync_to_async(serializer.is_valid)(raise_exception=True)
await self.perform_create(serializer)
headers = await self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
async def perform_create(self, serializer):
await sync_to_async(serializer.save)()
async def get_success_headers(self, data) -> Dict:
try:
return {"Location": str(data[api_settings.URL_FIELD_NAME])}
except (TypeError, KeyError):
return {}
class AsyncListModelMixin(ListModelMixin):
async def list(self, request, *args, **kwargs):
queryset = await self.filter_queryset(await self.get_queryset())
page = await self.paginate_queryset(queryset)
if page is not None:
serializer = await self.get_serializer(page, many=True)
return await self.get_paginated_response(serializer.data)
serializer = await self.get_serializer(queryset, many=True)
return Response(serializer.data)
class AsyncRetrieveModelMixin(RetrieveModelMixin):
"""
Retrieve a model instance.
"""
async def retrieve(self, request, *args, **kwargs):
instance = await self.get_object()
serializer = await self.get_serializer(instance)
return Response(serializer.data)
class AsyncUpdateModelMixin(UpdateModelMixin):
"""
Update a model instance.
"""
async def update(self, request, *args, **kwargs):
partial = kwargs.pop("partial", False)
instance = await self.get_object()
serializer = await self.get_serializer(instance, data=request.data, partial=partial)
await sync_to_async(serializer.is_valid)(raise_exception=True)
await self.perform_update(serializer)
if getattr(instance, "_prefetched_objects_cache", None):
instance._prefetched_objects_cache = {}
return Response(serializer.data)
async def perform_update(self, serializer):
await sync_to_async(serializer.save)()
async def partial_update(self, request, *args, **kwargs):
kwargs["partial"] = True
return await self.update(request, *args, **kwargs)
class AsyncDestroyModelMixin(DestroyModelMixin):
"""Make `destroy()` and `perform_destroy()` overridable.
Without inheriting this class, the event loop can't be used in these two methods when override them.
This must be inherited before `DestroyModelMixin`.
class MyViewSet(AsyncMixin, GenericViewSet, AsyncDestroyModelMixin, DestroyModelMixin):
pass
"""
async def destroy(self, request, *args, **kwargs):
instance = await self.get_object()
await self.perform_destroy(instance)
return Response(status=status.HTTP_204_NO_CONTENT)
async def perform_destroy(self, instance):
await sync_to_async(instance.delete)()
运用
class ProFileUserViewList(AsyncGenericAPIView):
queryset = ProfileUser.objects.all()
def get_serializer_class(self):
if self.request.method == "GET":
return ProfileUserListSerializer
async def get(self, request):
"""
用户列表
:return:
"""
结语
- 这只是完成视图步, 还短少filter,page,swagger等调用步
- 方才还看一下github的issues 想要rest支持异步还要不少时间,毕竟django从3.0到4.0也是经过5年
- 借用github issues一句话,Async IO doesn't necessarily speeds up your code, it just reduces the server resources usage for concurrent tasks by leveraging concurrent network IO in a single-threaded application.
- 以上只是个人研讨源码修正完成,我用的数据库是tortoise,但在django4.0上运用也是没有问题