asyncio实现异步socket server

1,874 阅读4分钟

1.阻塞式socket编程

最基础的socket编程代码,分为服务端和客户端

server.py

import socket

# Blocking echo server: serves exactly one client, then shuts down.
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ip_port = ('', 8000)
server_socket.bind(ip_port)
server_socket.listen(3)

print("waiting for connection...")
# Only one accept(): after this, the server never looks at other clients.
sock, addr = server_socket.accept()
while True:
    data = sock.recv(1024)
    msg_data = data.decode("utf-8")
    # recv() returns b"" once the client has closed its end (e.g. it was
    # killed without sending "exit"); without this check the loop would
    # spin forever echoing empty messages to a dead peer.
    if not data or msg_data == "exit":
        sock.close()
        break
    print("receive: ", msg_data)
    msg = b"echo:" + data
    # send() may write only part of the buffer; sendall() retries until done.
    sock.sendall(msg)
server_socket.close()
print("server shut down...")

client.py

import socket

# Interactive echo client: sends each typed line, prints the server's reply.
ip_port = ('127.0.0.1', 8000)

s = socket.socket()

s.connect(ip_port)

try:
    while True:
        inp = input("请输入要发送的信息: ").strip()
        if not inp:
            continue
        s.sendall(inp.encode())
        if inp == "exit":
            print("结束通信!")
            break
        server_reply = s.recv(1024)
        # An empty read means the server closed the connection; sending
        # again would raise, so stop here.
        if not server_reply:
            print("服务端已关闭连接!")
            break
        print(server_reply.decode())
finally:
    # Close the socket even on Ctrl-C or a connection error.
    s.close()

运行效果:

image-20220715145020387

可以看到,这种写法同一时间只能与一个客户端收发消息。这是因为server_socket.accept()和sock.recv()都是阻塞的,而服务端只调用了一次accept():接收到第一个客户端的连接后,就进入死循环跟这个客户端收发数据,再也不会去accept其他连接,其他客户端只能停留在listen的等待队列里。

2.使用多线程

server.py

import socket
import threading

# Listening TCP socket: '' binds all local interfaces, port 8000,
# with a backlog of 3 not-yet-accepted connections.
server_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
ip_port = ('', 8000)
server_socket.bind(ip_port)
server_socket.listen(3)


def process_client_data(s):
    """Echo data back to one client until it disconnects or sends "exit".

    Every received chunk is sent back prefixed with ``b"echo:"``.

    :param s: connected client socket; it is closed before returning.
    """
    while True:
        data = s.recv(1024)
        msg_data = data.decode("utf-8")
        # b"" means the peer closed the connection without sending "exit";
        # treat it like "exit" instead of spinning forever on empty reads.
        if not data or msg_data == "exit":
            s.close()
            print("sock shut down...")
            break
        print("receive: ", msg_data)
        msg = b"echo:" + data
        # sendall() keeps writing until the whole reply has been sent.
        s.sendall(msg)


def main():
    """Accept clients forever, handing each one to its own worker thread."""
    while True:
        print("waiting for connection...")
        conn, peer = server_socket.accept()
        print("receive connection from %s" % (peer,))
        worker = threading.Thread(target=process_client_data, args=(conn,))
        worker.start()


if __name__ == '__main__':
    main()

客户端的代码没有变化,运行结果如下:

image-20220715151404835

看起来比较完美,可以同时跟多个客户端收发消息,但是可不可以不使用多线程(或者多进程),我们只在单一进程的单一线程里面达到这个效果呢?

3.使用select

server.py

from collections import defaultdict
import socket
from select import select

server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# SO_REUSEADDR permits an immediate restart while the previous socket is
# still in TIME_WAIT. Unlike SO_REUSEPORT it is defined on every platform
# (socket.SO_REUSEPORT does not exist on Windows), and on Linux it does not
# let a second running copy of the server silently share the port.
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ip_port = ('', 8000)
server_socket.bind(ip_port)
server_socket.listen(3)

# Pending replies per client socket: filled on read, drained on write.
msg_queue = defaultdict(list)
# Sockets select() watches for readability (listening socket + clients).
inputs = [server_socket]
# Client sockets that have replies waiting to be written.
outputs = []


def process_new_conn(s):
    """Accept a pending connection and start watching it for incoming data."""
    client, peer = s.accept()
    print("receive connection from %s" % (peer,))
    inputs.append(client)


def process_receive_msg(s):
    """Read one chunk from client ``s`` and queue its echo reply.

    If the client sent ``"exit"`` or closed the connection (``recv`` returns
    ``b""``), the socket is closed and removed from every list ``select()``
    watches, together with any replies still queued for it.
    """
    data = s.recv(1024)
    msg_data = data.decode("utf-8")
    if not data or msg_data == "exit":
        # A closed socket left in inputs/outputs has fileno() == -1, which
        # makes the next select() call raise, so drop it everywhere first.
        inputs.remove(s)
        if s in outputs:
            outputs.remove(s)
        msg_queue.pop(s, None)
        s.close()
        print("sock shut down...")
        return
    print("receive: ", msg_data)
    msg = b"echo:" + data
    msg_queue[s].append(msg)
    # Watch for writability once, even if several chunks arrive before the
    # queued replies are flushed.
    if s not in outputs:
        outputs.append(s)


def process_send_msg(s):
    """Flush every queued reply for ``s`` and stop watching it for writability."""
    # pop() also drops the dict entry, so closed sockets do not pile up in
    # msg_queue; defaultdict recreates the list on the next append.
    for msg in msg_queue.pop(s, []):
        # sendall() keeps writing until the whole reply has been sent.
        s.sendall(msg)
    outputs.remove(s)


def main():
    """Run the select() loop: accept, read and write on whichever sockets are ready."""
    while True:
        rs, ws, es = select(inputs, outputs, inputs)
        for sock in rs:
            if sock is server_socket:
                process_new_conn(sock)
                continue
            process_receive_msg(sock)
        for sock in ws:
            # A client that sent "exit" may have been closed while its read
            # event was handled above; writing to it would raise OSError.
            if sock.fileno() == -1:
                continue
            process_send_msg(sock)


if __name__ == '__main__':
    main()

代码核心就是在循环中,不断地用select取出可以读和写的fd对象(此处是socket),然后对其进行不同的处理。

  • process_new_conn:有新连接时,将accept得到的sock加入inputs
  • process_receive_msg:有客户端发消息,生成回复的消息并加到msg_queue中,然后将sock加入到outputs
  • process_send_msg:需要给客户端回复消息,从msg_queue中取出消息并send后,将sock从outputs中去除掉

4.使用poll

server.py

from collections import defaultdict
import socket
from select import poll, POLLIN, POLLOUT

server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# SO_REUSEADDR permits an immediate restart while the previous socket is
# still in TIME_WAIT. SO_REUSEPORT would additionally let a second running
# copy of the server bind the same port (on Linux) and silently receive
# part of the connections.
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ip_port = ('', 8000)
server_socket.bind(ip_port)
server_socket.listen(3)

# Pending replies per client socket: filled on read, drained on write.
msg_queue = defaultdict(list)
# fd -> socket object, because poll() reports events by file descriptor.
socket_map = {server_socket.fileno(): server_socket}


def process_new_conn(s):
    """Accept a pending connection and watch it for incoming data."""
    c_sock, addr = s.accept()
    print("receive connection from %s" % (addr,))
    socket_map[c_sock.fileno()] = c_sock
    # A fresh client has nothing queued to send yet, so wait for it to become
    # readable (registering POLLOUT only caused one empty send round-trip).
    poll_obj.register(c_sock, POLLIN)


def process_receive_msg(s):
    """Read one chunk from client ``s``, queue its echo and switch to POLLOUT.

    If the client sent ``"exit"`` or closed the connection (``recv`` returns
    ``b""``), the socket is unregistered, forgotten and closed instead.
    """
    data = s.recv(1024)
    msg_data = data.decode("utf-8")
    if not data or msg_data == "exit":
        # Unregister before closing: a closed fd left in the poll set is
        # reported as POLLNVAL on every poll() call, busy-spinning the loop.
        poll_obj.unregister(s)
        # fileno() must be read before close(), which resets it to -1.
        socket_map.pop(s.fileno(), None)
        msg_queue.pop(s, None)
        s.close()
        print("sock shut down...")
        return
    print("receive: ", msg_data)
    msg = b"echo:" + data
    msg_queue[s].append(msg)
    poll_obj.modify(s, POLLOUT)


def process_send_msg(s):
    """Flush all queued replies for ``s`` and switch it back to POLLIN."""
    # pop() also drops the dict entry so closed sockets do not accumulate.
    for msg in msg_queue.pop(s, []):
        # sendall() keeps writing until the whole reply has been sent.
        s.sendall(msg)
    poll_obj.modify(s, POLLIN)


poll_obj = poll()


def main():
    """Register the listening socket and dispatch poll() events forever.

    ``mask`` is a bit set and one event can carry several flags at once
    (e.g. ``POLLOUT | POLLHUP``), so each flag is tested with ``&``; an
    ``==`` comparison silently ignores such combined events.
    """
    from select import POLLERR, POLLHUP, POLLNVAL

    poll_obj.register(server_socket, POLLIN)
    while True:
        # select.poll timeouts are in milliseconds.
        events = poll_obj.poll(30)
        for fd, mask in events:
            sock = socket_map.get(fd)
            if sock is None:
                # Already dropped while handling an earlier event.
                continue
            if mask & (POLLHUP | POLLERR | POLLNVAL):
                # Hang-up, error, or an fd closed without being unregistered:
                # drop it, otherwise poll() reports it again on every call.
                poll_obj.unregister(fd)
                del socket_map[fd]
                msg_queue.pop(sock, None)
                sock.close()
                continue
            if mask & POLLIN:
                if sock is server_socket:
                    process_new_conn(sock)
                else:
                    process_receive_msg(sock)
            elif mask & POLLOUT:
                process_send_msg(sock)


if __name__ == '__main__':
    main()

poll提供了register、modify、unregister、poll等方法。运行逻辑是:把需要监测的fd及其关心的事件register到poll对象中(需要更换监测的事件时用modify,不再需要时用unregister),然后通过poll方法不断取出就绪的fd,根据fd的事件类型(POLLIN, POLLOUT)对其进行不同的处理。

5.使用epoll

server.py

from collections import defaultdict
import socket
from select import epoll, EPOLLOUT, EPOLLIN, EPOLLHUP

server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# SO_REUSEADDR permits an immediate restart while the previous socket is
# still in TIME_WAIT. SO_REUSEPORT would additionally let a second running
# copy of the server bind the same port and silently receive part of the
# connections.
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ip_port = ('', 8000)
server_socket.bind(ip_port)
server_socket.listen(3)

epoll_obj = epoll()
# Pending replies per client socket: filled on read, drained on write.
msg_queue = defaultdict(list)
# fd -> socket object, because epoll reports events by file descriptor.
socket_map = {server_socket.fileno(): server_socket}


def process_new_conn(s):
    """Accept a pending connection and watch it for incoming data."""
    c_sock, addr = s.accept()
    print("receive connection from %s" % (addr,))
    socket_map[c_sock.fileno()] = c_sock
    # A fresh client has nothing queued to send yet, so wait for it to become
    # readable (registering EPOLLOUT only caused one empty send round-trip).
    epoll_obj.register(c_sock, EPOLLIN)


def process_receive_msg(s):
    """Read one chunk from client ``s``, queue its echo and switch to EPOLLOUT.

    If the client sent ``"exit"`` or closed the connection (``recv`` returns
    ``b""``), the socket is unregistered, forgotten and closed instead.
    """
    data = s.recv(1024)
    msg_data = data.decode("utf-8")
    if not data or msg_data == "exit":
        # Without the b"" check a disconnected client stays readable forever
        # and every read yields an empty "echo:" reply.
        epoll_obj.unregister(s)
        # fileno() must be read before close(), which resets it to -1.
        socket_map.pop(s.fileno(), None)
        msg_queue.pop(s, None)
        s.close()
        print("sock shut down...")
        return
    print("receive: ", msg_data)
    msg = b"echo:" + data
    msg_queue[s].append(msg)
    epoll_obj.modify(s, EPOLLOUT)


def process_send_msg(s):
    """Flush all queued replies for ``s`` and switch it back to EPOLLIN."""
    # pop() also drops the dict entry so closed sockets do not accumulate.
    for msg in msg_queue.pop(s, []):
        # sendall() keeps writing until the whole reply has been sent.
        s.sendall(msg)
    epoll_obj.modify(s, EPOLLIN)


def main():
    """Register the listening socket and dispatch epoll events forever.

    ``mask`` is a bit set and one event can carry several flags at once
    (e.g. ``EPOLLIN | EPOLLHUP``), so each flag is tested with ``&``; with
    ``==`` a hang-up almost never matched ``EPOLLHUP`` exactly.
    """
    from select import EPOLLERR

    epoll_obj.register(server_socket, EPOLLIN)
    while True:
        # epoll.poll() takes its timeout in seconds (select.poll uses ms).
        events = epoll_obj.poll(30)
        for fd, mask in events:
            sock = socket_map.get(fd)
            if sock is None:
                # Already dropped while handling an earlier event.
                continue
            if mask & (EPOLLHUP | EPOLLERR):
                epoll_obj.unregister(fd)
                del socket_map[fd]
                msg_queue.pop(sock, None)
                sock.close()
                continue
            if mask & EPOLLIN:
                if sock is server_socket:
                    process_new_conn(sock)
                else:
                    process_receive_msg(sock)
            elif mask & EPOLLOUT:
                process_send_msg(sock)


if __name__ == '__main__':
    main()

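可以看到,使用方式上跟poll基本一致。需要注意的是,两者poll()的超时参数单位不同:select.poll对象的poll(timeout)以毫秒为单位,而epoll对象的poll(timeout)以秒为单位。所以上面同样写的是30,含义分别是30毫秒和30秒。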

6.使用asyncio

server.py

import asyncio
import socket

# Calling asyncio.get_event_loop() with no current loop is deprecated (and
# raises RuntimeError from Python 3.14), so create and install one explicitly.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# SO_REUSEADDR permits an immediate restart while the previous socket is in
# TIME_WAIT, without letting a second running server share the port the way
# SO_REUSEPORT does (and it also exists on Windows).
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ip_port = ('', 8000)
server_socket.bind(ip_port)
server_socket.listen(3)


def accept_conn(ser_sock):
    """Reader callback for the listening socket: accept, then watch the client."""
    client, peer = ser_sock.accept()
    print("receive connection from %s" % (peer,))
    loop.add_reader(client, receive_data, client)


def receive_data(sock):
    """Reader callback for a client: schedule an echo, or close on EOF/"exit".

    Mirrors the other servers in this article: ``"exit"`` or a closed
    connection (``recv`` returns ``b""``) ends the session.
    """
    data = sock.recv(1024)
    msg_data = data.decode("utf-8")
    if not data or msg_data == "exit":
        # Stop watching the fd before closing it; otherwise a disconnected
        # client keeps the reader firing forever on empty reads.
        loop.remove_reader(sock)
        loop.remove_writer(sock)
        sock.close()
        print("sock shut down...")
        return
    print("receive:", msg_data)
    msg = b"echo:" + data
    loop.add_writer(sock, send_data, sock, msg)


def send_data(sock, msg):
    """Writer callback: send the reply, then stop watching for writability."""
    # send() may write only part of msg; sendall() retries until all is sent.
    sock.sendall(msg)
    loop.remove_writer(sock)


loop.add_reader(server_socket, accept_conn, server_socket)

try:
    loop.run_forever()
finally:
    # Also runs on Ctrl-C: release the listening socket and the loop instead
    # of leaving them to the interpreter (which emits ResourceWarnings).
    loop.remove_reader(server_socket)
    server_socket.close()
    loop.close()

这个才是咱们重点想要讲的。前面的select、poll、epoll实际上都被标准库的selectors模块封装了一层,asyncio里面的event loop正是通过selectors继续利用这几种I/O多路复用技术。

具体如何利用的呢?咱们可以这样看

  • BaseEventLoop.run_forever

    image-20220715180424342

  • BaseEventLoop._run_once

        def _run_once(self):
            """Run one full iteration of the event loop.
    
            This calls all currently ready callbacks, polls for I/O,
            schedules the resulting callbacks, and finally schedules
            'call_later' callbacks.
            """
    
            sched_count = len(self._scheduled)
            if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
                self._timer_cancelled_count / sched_count >
                    _MIN_CANCELLED_TIMER_HANDLES_FRACTION):
                # Remove delayed calls that were cancelled if their number
                # is too high
                new_scheduled = []
                for handle in self._scheduled:
                    if handle._cancelled:
                        handle._scheduled = False
                    else:
                        new_scheduled.append(handle)
    
                heapq.heapify(new_scheduled)
                self._scheduled = new_scheduled
                self._timer_cancelled_count = 0
            else:
                # Remove delayed calls that were cancelled from head of queue.
                while self._scheduled and self._scheduled[0]._cancelled:
                    self._timer_cancelled_count -= 1
                    handle = heapq.heappop(self._scheduled)
                    handle._scheduled = False
    
            # How long the select() call below may wait: 0 when callbacks are
            # already ready or the loop is stopping, otherwise until the
            # earliest scheduled timer (capped at MAXIMUM_SELECT_TIMEOUT);
            # it stays None when nothing is scheduled.
            timeout = None
            if self._ready or self._stopping:
                timeout = 0
            elif self._scheduled:
                # Compute the desired timeout.
                when = self._scheduled[0]._when
                timeout = min(max(0, when - self.time()), MAXIMUM_SELECT_TIMEOUT)
    
            # The I/O multiplexing step. self._selector is assigned in
            # BaseSelectorEventLoop.__init__ (selectors.DefaultSelector()
            # unless one is passed in), so this ends up in kqueue / epoll /
            # devpoll / poll / select depending on the platform.
            event_list = self._selector.select(timeout)
            # Ready I/O events are turned into callbacks queued on
            # self._ready; per the note below, they only run further down.
            self._process_events(event_list)
    
            # Handle 'later' callbacks that are ready.
            end_time = self.time() + self._clock_resolution
            while self._scheduled:
                handle = self._scheduled[0]
                if handle._when >= end_time:
                    break
                handle = heapq.heappop(self._scheduled)
                handle._scheduled = False
                self._ready.append(handle)
    
            # This is the only place where callbacks are actually *called*.
            # All other places just add them to ready.
            # Note: We run all currently scheduled callbacks, but not any
            # callbacks scheduled by callbacks run this time around --
            # they will be run the next time (after another I/O poll).
            # Use an idiom that is thread-safe without using locks.
            ntodo = len(self._ready)
            for i in range(ntodo):
                handle = self._ready.popleft()
                if handle._cancelled:
                    continue
                if self._debug:
                    try:
                        self._current_handle = handle
                        t0 = self.time()
                        handle._run()
                        dt = self.time() - t0
                        if dt >= self.slow_callback_duration:
                            logger.warning('Executing %s took %.3f seconds',
                                           _format_handle(handle), dt)
                    finally:
                        self._current_handle = None
                else:
                    handle._run()
            handle = None  # Needed to break cycles when an exception occurs.
    

    注意这两行代码

    image-20220715180737111

  • BaseSelectorEventLoop._process_events

    image-20220715180930293

    看到了selectors、EVENT_READ和EVENT_WRITE这样的代码,是否感觉跟前面三种写法有点相似呢?

  • BaseSelectorEventLoop.__init__

        def __init__(self, selector=None):
            super().__init__()
    
            if selector is None:
                # No selector given: use the platform's preferred multiplexer.
                selector = selectors.DefaultSelector()
            logger.debug('Using selector: %s', selector.__class__.__name__)
            # _run_once() waits for I/O through self._selector.select(timeout).
            self._selector = selector
            self._make_self_pipe()
            self._transports = weakref.WeakValueDictionary()
    

    可以看到self._selector:如果没有传入selector,默认使用的就是selectors.DefaultSelector()。

  • selectors.DefaultSelector()

    # Choose the best implementation, roughly:
    #    epoll|kqueue|devpoll > poll > select.
    # select() also can't accept a FD > FD_SETSIZE (usually around 1024)
    # Checked in the order kqueue, epoll, devpoll, poll; SelectSelector is
    # the fallback.
    # NOTE(review): each XxxSelector class is presumably only defined when the
    # matching select-module primitive exists on this platform, hence the
    # globals() membership tests — confirm against selectors.py.
    if 'KqueueSelector' in globals():
        DefaultSelector = KqueueSelector
    elif 'EpollSelector' in globals():
        DefaultSelector = EpollSelector
    elif 'DevpollSelector' in globals():
        DefaultSelector = DevpollSelector
    elif 'PollSelector' in globals():
        DefaultSelector = PollSelector
    else:
        DefaultSelector = SelectSelector
    

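    最终可以看到,asyncio的loop在每一轮_run_once中,先调用self._selector.select(timeout)等待I/O就绪,再由_process_events把就绪事件转换成待执行的回调。而DefaultSelector会按平台依次选用kqueue、epoll、devpoll、poll,最后才退回到select。也就是说,asyncio底层利用的正是前面介绍的这几种IO多路复用实现。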