Skynet socket.read 方法解析

1,277 阅读9分钟

最近学习skynet,发现其他方面的分析文章都挺多,但是关于socket的很少,所以想尝试写一点。今天先来分析一下socket.read 方法

socket.read(id, sz) 

我们首先看一下官方文档是怎么写的: 从一个 socket 上读 sz 指定的字节数。如果读到了指定长度的字符串,它把这个字符串返回。如果连接断开导致字节数不够,将返回一个 false 加上读到的字符串。如果 sz 为 nil ,则返回尽可能多的字节数,但至少读一个字节(若无新数据,会阻塞)。 所谓阻塞模式,实际上是利用了 lua 的 coroutine 机制。当你调用 socket api 时,服务有可能被挂起(时间片被让给其他业务处理),待结果通过 socket 消息返回,coroutine 将延续执行。

分析read方法有两个重点,一是同步调用的原理,二是数据的存取方式。

这里主要涉及到socket.lua 和 lua-socket.c 两个文件

我们首先来看一下lua层

lua层处理

function socket.read(id, sz)
	local s = socket_pool[id]

	local ret = driver.pop(s.buffer, s.pool, sz) --调用lpopbuffer
	if ret then --如果返回的不是nil,表示读取到了需要的数据,直接返回
		return ret
	end
	--否则就只有暂停当前协程,直到从对端发来足够的数据后再恢复执行
	--在暂停之前会判断对端是不是已经断开连接了,如果是的话,就不暂停当前协程了,直接返回目前已经读到的所有数据
	if not s.connected then
		return false, driver.readall(s.buffer, s.pool)
	end

	s.read_required = sz --保存需要读取的size,这样当对端发送过来更多的数据时我们才能判断够不够
	suspend(s) --挂起当前协程,当数据够或者连接断开时恢复协程
	ret = driver.pop(s.buffer, s.pool, sz)
	if ret then
		return ret
	else
		return false, driver.readall(s.buffer, s.pool)
	end
end

local function suspend(s)
	s.co = coroutine.running() --保存协程
	skynet.wait(s.co) --让出当前协程执行权,再次恢复时也是从这里开始执行
end

上面的代码删除了一些额外判断,我们只看一下主要的流程,可以看到socket.read 首先尝试读取指定字节数的数据(这里调用了 lua-socket.c 的 lpopbuffer 方法,我们后面具体分析),如果读取到了,就直接返回,否则就让出当前协程的执行权,直到对端发来了更多的数据满足了我们的要求,才会重新恢复当前协程的执行。我们看到这里s用read_required和co字段分别保存了要求的字节数和当前运行的协程,这样当收到数据时我们才能够判断够不够以及应该恢复哪个协程。

这里skynet.wait 方法 就不分析了,如果感兴趣的话可以参考 zhuanlan.zhihu.com/p/84653538 这篇文章

下面我们看一下接收到新数据后的处理

-- read skynet_socket.h for these macro
-- SKYNET_SOCKET_TYPE_DATA = 1
socket_message[1] = function(id, size, data) 
	local s = socket_pool[id]
	local sz = driver.push(s.buffer, s.pool, data, size) --调用lpushbuffer,返回未读取的数据长度
	local rr = s.read_required --取出要读取的数据长度
	local rrt = type(rr)
	if rrt == "number" then
		if sz >= rr then --如果有足够的数据可读
			s.read_required = nil
			wakeup(s) --唤起调用read的协程
		end
	else
	end
end

local function wakeup(s)
	local co = s.co
	if co then
		s.co = nil
		skynet.wakeup(co)
	end
end

当接收到新数据后,上面的函数会作为回调函数被调用。首先会把数据保存起来,然后会取出我们在上面暂停读取协程时存的要求的字节数,如果判断目前能够读取的数据长度已经满足要求了,就唤醒睡眠的读取协程

c层处理

在具体看读取及写入方法前,我们先来看一下存储网络数据包的结构

//存放所有未读取的数据
struct socket_buffer {
    int size;                 // 还未读取的网络数据总长度
    int offset;               // head 已读数据的偏移
    struct buffer_node *head; // 数据buff_node链表的头部指针
    struct buffer_node *tail;
};

struct buffer_node {
    char *msg;
    int sz;                   //该buffer_node存储的数据size
    struct buffer_node *next; 
};

可以看到 socket_buffer 就是一个链表结构,当接收到新的网络数据时,就会新构造一个节点插入到尾部,每次读取都从头节点开始读,如果一次读不完,就在 offset 记录头节点还有多少数据没读,当头节点的数据全都读完后,就移除头结点,将下一个节点重新设为头结点

接下来我们看一下读取的方法

lpopbuffer 读取数据

/*
	userdata send_buffer
	table pool
	integer sz 
 */
static int
lpopbuffer(lua_State *L) {
    struct socket_buffer *sb = lua_touserdata(L, 1);
    if (sb == NULL) {
        return luaL_error(L, "Need buffer object at param 1");
    }
    luaL_checktype(L, 2, LUA_TTABLE);
    int sz = luaL_checkinteger(L, 3);
    if (sb->size < sz || sz == 0) { //如果此次没有足够的数据可读
        lua_pushnil(L);             //直接返回一个nil
    } else {
        pop_lstring(L, sb, sz, 0); //从sb中取出sz字节的数据
        sb->size -= sz;            //更新size
    }
    lua_pushinteger(L, sb->size); //同时返回剩下可读的字节数

    return 2;
}

可以看到 lpopbuffer 首先判断可读的字节数够不够,如果不够直接返回nil,如果够的话就取出需要长度的数据,除了返回数据之外还会额外返回剩下可读的字节数。接下来我们进到pop_lstring 方法,看看具体是如何从sb中取出数据的。

static void
pop_lstring(lua_State *L, struct socket_buffer *sb, int sz, int skip) {
    struct buffer_node *current = sb->head; //取出头节点
    if (sz < current->sz - sb->offset) {    //如果有足够数据,并且还有多的
        lua_pushlstring(L, current->msg + sb->offset, sz - skip);
        sb->offset += sz;
        return;
    }
    if (sz == current->sz - sb->offset) { //刚好够(此时就需要移除头结点了)
        lua_pushlstring(L, current->msg + sb->offset, sz - skip);
        return_free_node(L, 2, sb); //移除头结点
        return;
    }

    luaL_Buffer b;
    luaL_buffinitsize(L, &b, sz);
    for (;;) {
        int bytes = current->sz - sb->offset; //sb当前节点可读的数据
        if (bytes >= sz) {                    //如果当前节点的数据够了,就读取需要的,然后退出循环返回
            if (sz > skip) {
                luaL_addlstring(&b, current->msg + sb->offset, sz - skip);
            }
            sb->offset += sz;
            if (bytes == sz) {
                return_free_node(L, 2, sb);
            }
            break;
        }
        //如果当前节点的数据不够,就全部读完,更新当前节点为下一个节点,继续循环
        int real_sz = sz - skip;
        if (real_sz > 0) {
            luaL_addlstring(&b, current->msg + sb->offset, (real_sz < bytes) ? real_sz : bytes);
        }
        return_free_node(L, 2, sb);
        sz -= bytes; //更新需要读取的数据
        if (sz == 0) //需要读取的数据为0,表示已经读够了,退出循环
            break;
        current = sb->head; //更新current节点为新的头结点
        assert(current);
    }
    luaL_pushresult(&b);
}

可以看到 pop_lstring 从头结点开始读,直到读到需要的字节数后返回。如果头结点剩余的数据刚好够,那么只用读取头结点就可以了,否则就需要从头结点开始依次读取链表的节点。从上一步看到传过来的skip参数为0,所以在我们的场景中可以直接忽略这个参数。值得注意的是,每次头结点的数据读取完毕后,我们都要进行 sb 链表的更新,将头结点的下一个节点设为新的头结点。这里更新链表的return_free_node 方法还涉及到读写缓冲区。我们在后面具体分析。

我们再看一下写入的方法

lpushbuffer 写入数据

/*
	userdata send_buffer
	table pool
	lightuserdata msg
 */
	int size
static int
lpushbuffer(lua_State *L) {
    struct socket_buffer *sb = lua_touserdata(L, 1);
    if (sb == NULL) {
        return luaL_error(L, "need buffer object at param 1");
    }
    char *msg = lua_touserdata(L, 3);
    if (msg == NULL) {
        return luaL_error(L, "need message block at param 3");
    }
    int pool_index = 2;
    luaL_checktype(L, pool_index, LUA_TTABLE);
    int sz = luaL_checkinteger(L, 4);

    lua_rawgeti(L, pool_index, 1);                         //把pool[1]压栈
    struct buffer_node *free_node = lua_touserdata(L, -1); // 拿到pool[1]处存的 free_node

    //将网络数据的指针和大小保存在这个空闲的 buff_node上
    free_node->msg = msg;
    free_node->sz = sz;
    free_node->next = NULL;

    //把存储了数据的 buff_node 插入到sb的尾部
    if (sb->head == NULL) {
        assert(sb->tail == NULL);
        sb->head = sb->tail = free_node;
    } else {
        sb->tail->next = free_node;
        sb->tail = free_node;
    }
    sb->size += sz;

    lua_pushinteger(L, sb->size);

    return 1;
}

可以看到 lpushbuffer 首先会从缓冲池里拿到一个空闲节点free_node(先不管怎么拿到的),然后把数据保存在这个free_node 中,最后把它插入到sb 的尾部。如此就完成了数据的存放。

读写缓冲区pool

首先我们先想一下我们具体缓冲的是什么。通过上面代码我们可以看到,我们每次接收到要存储的数据都是一个指针加一个size(因为都是在一个进程内,所以可以直接传递指针),可以直接使用buffer_node来存储。但是我们不想每次有数据来都新建一个buffer_node,当buffer_node里数据读完之后再销毁这个buffer_node,我们希望可以复用它。如此我们就需要有一个buffer_node池,来数据时就从池里取一个buffer_node,当某个buffer_node读完时就把它归还到池中。

在skynet这个缓冲pool是通过lua的table 实现的,但是不管是设置还是使用都是在c代码中,每次都是通过lua传参s.pool传给c函数。

在具体看这个table的结构之前,我们先来分析一下应该怎样实现。 首先table里面应该存了很多buffer_node,这些buffer_node里有已经存放数据的,有没存放数据作为空闲节点存在的。为了更方便的找到空闲节点,我们需要一个额外的free_node指针(为了方便,池里面使用buffer_node结构来存放它)来指向空闲节点,而每个空闲节点又有next指针指向下一个空闲节点,如果就可以很方便的对空闲节点做增删了

既然是缓冲池,那当接收的数据过多,缓冲都被用完之后,就需要扩大缓冲池,也就是是增加buffer_node, 那每次增加多少个buffer_node合适呢,太多也不好,太少也不好,参考其他的扩容策略,每次增加值都是上一次增加值的两倍是最为合适的。但是首先我们要知道上次增加了多少,目前pool是保存在lua代码里的,如果这个值也保存在lua代码里,那么每次调用c函数时都要传递进去,调用之后还要传递回来,就很不方便。其实我们可以巧妙的利用lua table的容量,每次增加时我们就新建一个表项,通过这个表项的index来计算出此次需要增加多少个buffer_node

pool = {
	[1] = free_node, -- 指向下面31个表项中某个buff_node,作为空闲链表的头结点
	[2] = buffer_node_pool, -- 存放16个buff_node
	[3] = buffer_node_pool, -- 存放32个buff_node
	...
	[32] = buffer_node_pool, -- 存放4096个buff_node
}

最终结构应该如上所示,注意下面的表项一开始都是不存在的。

我们再来看一下lpushbuffer方法,这次只关注和缓冲池有关的。

static int
lpushbuffer(lua_State *L) {
    int pool_index = 2;
    luaL_checktype(L, pool_index, LUA_TTABLE);

    lua_rawgeti(L, pool_index, 1);                         //把pool[1]压栈
    struct buffer_node *free_node = lua_touserdata(L, -1); // 拿到pool[1]处存的 free_node
    lua_pop(L, 1); //弹出pool[1]
    if (free_node == NULL) {                 //为 null,则表示已经没有空闲的 buff_node
        int tsz = lua_rawlen(L, pool_index); //拿到作为pool的table的元素个数
        if (tsz == 0)
            tsz++;
        int size = 8;
        //每次*2,达到4096就不再增了
        if (tsz <= LARGE_PAGE_NODE - 3) {
            size <<= tsz;
        } else {
            size <<= LARGE_PAGE_NODE - 3;
        }
        lnewpool(L, size);                   //初始化这个表项存储的所有buffer_node
        free_node = lua_touserdata(L, -1);   //free_node指向这个表项的第一个元素
        lua_rawseti(L, pool_index, tsz + 1); //弹出栈顶值并赋值给pool[tsz+1]
    }
    lua_pushlightuserdata(L, free_node->next); // 将free_node 指向的下一个空闲 buff_node 指针赋值到 pool[1]
    lua_rawseti(L, pool_index, 1);             // pool[1]重新赋值到pool里
}

static int
lnewpool(lua_State *L, int sz) {
    struct buffer_node *pool = lua_newuserdatauv(L, sizeof(struct buffer_node) * sz, 0); //新建pool的一个表项,并放到栈顶
    int i;
    for (i = 0; i < sz; i++) {
        pool[i].msg = NULL;
        pool[i].sz = 0;
        pool[i].next = &pool[i + 1];
    }
    pool[sz - 1].next = NULL;
    if (luaL_newmetatable(L, "buffer_pool")) { //只有第一次执行时,表不存在才会返回1,不过存不存在都会把表压栈
        lua_pushcfunction(L, lfreepool);
        lua_setfield(L, -2, "__gc");
    }
    lua_setmetatable(L, -2); //把一张表弹出栈,并将其设为给定索引处的值的元表。
    return 1;
}

可以看到空闲节点使用完之后就进行了扩容,扩容之后就把新增buffer_node中的首个赋值给free_node。另外在扩容的时候,如果发现是首次使用,就重设了__gc元方法。

接下来再看一下sb头结点的数据读取完之后调用的return_free_node

//取出sb头结点,清空,并将它塞入到pool[1]里作为新的free_node结点(重新放回到缓冲池里),原来的free_node节点作为它的下一节点
static void
return_free_node(lua_State *L, int pool, struct socket_buffer *sb) {
    struct buffer_node *free_node = sb->head; //把head设为free_node
    //以下是调整sb链表
    sb->offset = 0;
    sb->head = free_node->next;
    if (sb->head == NULL) {
        sb->tail = NULL;
    }
    lua_rawgeti(L, pool, 1);                 //pool[1]压栈
    free_node->next = lua_touserdata(L, -1); //原来的free_node成为新的free_node的next节点
    lua_pop(L, 1);                           //弹出pool[1]
    skynet_free(free_node->msg);
    free_node->msg = NULL;

    free_node->sz = 0;
    lua_pushlightuserdata(L, free_node); //把新free_node压栈
    lua_rawseti(L, pool, 1);             //弹出刚刚压栈的free_node,并设置为pool[1]
}

可以看到return_free_node 主要做的就是把读取完的head节点重新放到缓冲池里作为新的free_node(即更新空闲链表的头结点)

参考: