手写NSQ(一)

1,186

nsq 是 go 语言体系下一款非常著名的消息队列。它本身非常容易上手,文档写的也很清晰,同时它的源码也并不复杂。这里不能不提一下,使用 go 的通道来编写消息队列的话是真的方便。我们从生产者接收消息后为了效率考虑肯定是要存到内存中,如果使用共享内存的话,比如Java中的容器,那么存入消息和取出消息都需要加锁;但是在 go 语言中,我们就不需要加锁,一个协程往通道里扔,另一个协程从通道里取即可,非常方便。

为了更深入的理解消息队列,我决定用一个系列文章来实现一个迷你版的nsq。本篇文章将会实现nsq最基本的功能:生产者发布消息和消费者订阅接收消息。后续的每一版都会在前一版的基础进行完善。

在看代码之前,我们先来回顾一下nsq的消息传递过程。

nsq的消息传递过程

18ca4ea10aca.gif

连接到nsq的客户端有两种角色:生产者或者消费者,当然也有可能既是生产者也是消费者。
如上图,生产者会往指定的 topic 中发送 message,topic 则会将 message 复制给它下面的所有channel ,然后 channel 再将消息发送给订阅的消费者,如果多个消费者订阅同一个 channel,那么该 channel 中每个 message 将被发送给一个随机的消费者。

mini-nsq 的工作流程

  1. 开启一个 tcp 服务监听来自客户端的连接请求
  2. 为每一个成功连接的客户端创建一个工作协程来处理来自客户端的消息(包括消费者的订阅,生产者发布消息等)
  3. 生产者发送过来的消息会暂存到对应的topic中,同时每个topic都有一个工作协程将其中暂存的的消息发送到订阅了该topic的channel里
  4. 为每一个成功连接的客户端再创建一个工作协程负责将绑定的channel 中的 message 发送给该客户端

代码分析

如下图,代码主要分为两部分,apps下的是测试程序,nsqd下的则是我们的工程文件

Snipaste_2021-05-09_14-06-30.png

基础组件(Message,Topic,Channel)

message.go

type MessageID [MsgIDLength]byte

const (
	MsgIDLength       = 16
)

type Message struct {
	ID        MessageID
	Body      []byte
}

func NewMessage(id MessageID, body []byte) *Message {
	return &Message{
		ID:        id,
		Body:      body,
	}
}

func (m *Message) Bytes() ([]byte,error) {
	buf := new(bytes.Buffer)
	_, err := buf.Write(m.ID[:])
	if err != nil {
		return nil, err
	}
	_, err = buf.Write(m.Body)
	if err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

可以看到 message 由两部分组成:16字节的 ID 和 Body;Bytes 方法用于将 message 转换为字节切片以进行最终的网络传输。

topic.go

type Topic struct {
	name              string

	sync.RWMutex 	//为了保证channelMap的安全
	channelMap        map[string]*Channel //记录该 topic 下的所有 channel

	memoryMsgChan     chan *Message //暂存发送到该 topic 中的 message
	channelUpdateChan chan int  //当该 topic 下的 channel 发生变动时,用于通知
}

func NewTopic(topicName string) *Topic {
	t := &Topic{
		name:              topicName,
		channelMap:        make(map[string]*Channel),
		memoryMsgChan:     make(chan *Message, 10000),
		channelUpdateChan: make(chan int),
	}
	go t.messagePump() //开启工作协程
	return t
}

Topic 结构体中的字段具体含义见注释。注意当我们 new 一个 topic 时会为该 topic 开启一个工作协程用于将该topic 中暂存的 message 传递给该 topic 下的 channel 中,如下

func (t *Topic) messagePump() {
	var msg *Message
	var chans []*Channel
	var memoryMsgChan chan *Message

	t.Lock()
	for _, c := range t.channelMap {
		chans = append(chans, c)
	}
	t.Unlock()

	if len(chans) > 0 {
		memoryMsgChan = t.memoryMsgChan
	}

	// main message loop
	for {
		select {
		case msg = <-memoryMsgChan:
			//到这里的时候chans的数量必须 >0,否则消息就丢失了,
			//所以我们处理时会在chans为 0的时候将memoryMsgChan置为nil
			for _, channel := range chans {
				err := channel.PutMessage(msg)
				if err != nil {
					log.Printf(
						"TOPIC(%s) ERROR: failed to put msg(%s) to channel(%s) - %s",
						t.name, msg.ID, channel.name, err)
				}
			}
		case <-t.channelUpdateChan:
			log.Println("topic 更新 channel")
			chans = chans[:0]
			t.Lock()
			for _, c := range t.channelMap {
				chans = append(chans, c)
			}
			t.Unlock()
			if len(chans) == 0 {
				memoryMsgChan = nil
			} else {
				memoryMsgChan = t.memoryMsgChan
			}
		}
	}
}

我们用 chans 来保存该 topic 下的所有 channel,因为channel是动态增减的,当 channel 变动时,channelUpdateChan 就会传递该消息,我们就需要重新遍历 channelMap 给 chans 赋值。注意如果 topic 下目前没有任何 channel,此时我们千万不能从 t.memoryMsgChan 中取出 message ,否则就会造成消息的丢失。这里用到的一个技巧就是使用了一个中间变量 memoryMsgChan,当 topic 下 channel 数量为0时,将 memoryMsgChan 置为nil,这样 select 就不会进入该分支。

其他的一些方法

func (t *Topic) PutMessage(m *Message) error {
	log.Printf("message 进入 topic")
	t.memoryMsgChan <- m
	return nil
}

func (t *Topic) GetChannel(channelName string) *Channel {
	t.Lock()
	channel, isNew := t.getOrCreateChannel(channelName)
	t.Unlock()
	if isNew {
		t.channelUpdateChan <- 1
	}
	return channel
}

// this expects the caller to handle locking
func (t *Topic) getOrCreateChannel(channelName string) (*Channel, bool) {
	channel, ok := t.channelMap[channelName]
	if !ok {
		channel = NewChannel(t.name, channelName)
		t.channelMap[channelName] = channel
		log.Printf("TOPIC(%s): new channel(%s)", t.name, channel.name)
		return channel, true
	}
	return channel, false
}


func (t *Topic) GenerateID() MessageID {
	var h MessageID
	return h
}

PutMessage 负责往 topic 中存入 message;GetChannel 则会获取已有的或者新建指定的 Channel,注意如果是新建的话就会往 channelUpdateChan 中发送一个信号;GenerateID 是为 message 生成一个唯一的id,因为我们暂时用不到这个 id,所以目前就简略处理了。

channel.go

type Channel struct {
	topicName string
	name      string
	memoryMsgChan chan *Message  //暂存发送到该channel下的message
}

// NewChannel creates a new instance of the Channel type and returns a pointer
func NewChannel(topicName string, channelName string) *Channel {
	return &Channel{
		topicName:      topicName,
		name:           channelName,
		memoryMsgChan:  make(chan *Message, 10000),
	}
}


// PutMessage writes a Message to the queue
func (c *Channel) PutMessage(m *Message) error {
	log.Printf("message 进入 channel,body:%s",m.Body)
	c.memoryMsgChan <- m
	return nil
}

Channel 结构体很简单,PutMessage 方法在上面 topic 的工作协程中被调用用于将 topic 中的 message 传递给该 channel,那么这些暂存在channel 中的 message 又是在什么时候发送给对应的消费者呢?我们接着往下看。

其他组件

程序入口(nsqd.go)

type NSQD struct {
	sync.RWMutex
	topicMap map[string]*Topic
}

func Start() (*NSQD, error) {
	var err error
	var tcpListener net.Listener
	tcpListener, err = net.Listen("tcp", "0.0.0.0:4150")
	if err != nil {
		return nil, fmt.Errorf("listen (%s) failed - %s", "0.0.0.0:4150", err)
	}
	log.Printf("TCP: listening on %s", tcpListener.Addr())

	n := &NSQD{
		topicMap:             make(map[string]*Topic),
	}
	tcpServer := &tcpServer{nsqd: n,tcpListener: tcpListener}
	go tcpServer.serve()
	return n, nil
}


// GetTopic performs a thread safe operation
// 没有就新建
func (n *NSQD) GetTopic(topicName string) *Topic {
	n.Lock()
	defer n.Unlock()
	t, ok := n.topicMap[topicName]
	if ok {
		return t
	}
	t = NewTopic(topicName)
	n.topicMap[topicName] = t
	return t
}

// channels returns a flat slice of all channels in all topics
func (n *NSQD) channels() []*Channel {
	var channels []*Channel
	n.RLock()
	for _, t := range n.topicMap {
		t.RLock()
		for _, c := range t.channelMap {
			channels = append(channels, c)
		}
		t.RUnlock()
	}
	n.RUnlock()
	return channels
}

NSQD 结构体存放了该消息队列中的所有topic;它的 GetTopic 方法用于根据给定的 topicName 取出对应的 topic,如果没有就新建;它的 channels 方法用于取出该消息队列中所有的channel。这两个方法都不难,我们主要看下程序的启动入口 Start 方法。 Start 方法首先开启了一个tcp 服务端监听4150端口,然后创建了一个 tcpServer 结构体,并且开启了一个新的协程来执行 tcpServer 的 serve 方法。

tcp_server.go

type tcpServer struct {
	nsqd  *NSQD
	tcpListener   net.Listener
}


func (tcpServer *tcpServer) serve () error {
	for {
		clientConn, err := tcpServer.tcpListener.Accept()
		if err != nil {
			break
		}
		//每个客户端来连接都起一条协程来处理
		go func() {
			log.Printf("TCP: new client(%s)", clientConn.RemoteAddr())

			prot := &protocolV2{nsqd: tcpServer.nsqd}

			client := prot.NewClient(clientConn)

			err := prot.IOLoop(client)
			if err != nil {
				log.Printf("client(%s) - %s", clientConn.RemoteAddr(), err)
			}
			client.Close()
		}()
	}

	return nil
}

可以看到 serve 方法就是在一个无限for循环中,不停的接收来自客户端的连接,并且为每一个成功建立连接的客户端创建一个工作协程来进行接下来的各种工作。具体的,在该工作协程中我们会创建一个 protocol 结构体和一个 client 结构体,然后调用 protocol 的 IOLoop 方法。

在看protocol.go 之前,我们先看一下client.go

client.go

const defaultBufferSize = 16 * 1024

type client struct {
	sync.Mutex

	// original connection
	net.Conn

	// reading/writing interfaces
	Reader *bufio.Reader
	Writer *bufio.Writer

	SubEventChan      chan *Channel //传递订阅事件


}

func newClient(conn net.Conn) *client {
	c := &client{
		Conn: conn,
		Reader: bufio.NewReaderSize(conn, defaultBufferSize),
		Writer: bufio.NewWriterSize(conn, defaultBufferSize),
		//这里有缓存是为了防止处理订阅事件的协程还没准备好订阅事件就来了导致订阅阻塞,
		//因为一个消费者只能订阅一次,所以这里容量为1
		SubEventChan:      make(chan *Channel, 1),
	}
	return c
}

func (c *client) Flush() error {
	return c.Writer.Flush()
}

可以看到 client 结构体中主要保存了一个 SubEventChan 用来传递订阅事件,同时因为使用了bufio 缓存,所以有一个Flush 方法用于强制消息的发送。

protocal.go

protocal.go 承担了包括从生产者中接收 message ,接受消费者的订阅以及向消费者发送 message,在每一次网络收发中根据协议编码以及解码消息等责任,可以说是最复杂的一个文件了,接下来我们慢慢解析。

//每个客户端都有一个对应的工作协程,负责接收来自客户端的消息,并进行实际处理
func (p *protocol) IOLoop(client *client) error {
        var err error
	var line []byte

	//另起一条协程处理消费者相关
	go p.messagePump(client)

	for {
		line, err = client.Reader.ReadSlice('\n')
		if err != nil {
			if err == io.EOF {
				err = nil
			} else {
				err = fmt.Errorf("failed to read command - %s", err)
			}
			break
		}

		// trim the '\n'
		line = line[:len(line)-1]
		// optionally trim the '\r'
		if len(line) > 0 && line[len(line)-1] == '\r' {
			line = line[:len(line)-1]
		}
		params := bytes.Split(line, separatorBytes)

		err = p.Exec(client, params)
		if err != nil {
			break
		}
	}
	return err

首先,从前面我们知道每个连接的客户端都有一个专门的协程来执行IOLoop 方法。然后,从代码中可以看出 IOLoop 方法主要就是在一个无限for循环中,不停的接收来自该客户端的消息并进行处理。此处我们暂时只处理 PUB 和 SUB 消息,前者表示客户端作为生产者向指定的topic 发送某个 message, 后者表示该客户端作为消费者订阅指定 topic 指定 channel。我们先看下这两种消息的格式

PUB
Publish a message to a topic:

PUB <topic_name>\n
[ 4-byte size in bytes ][ N-byte binary data ]

SUB
Subscribe to a topic/channel

SUB <topic_name> <channel_name>\n

然后我们具体看下 IOLoop 是如何解析来自客户端的消息的:首先读取消息直到 '\n',然后将读取的内容去除 '\n' 后按照空格进行切分得到 params ,在 Exec 方法中我们根据 params 的第一个元素的值来区分消息的类型并调用相应的方法进行处理,注意对于“PUB”类型的消息我们其实还没有读完,在PUB 方法中我们还会进行一次读取,具体参见下面分析。

func (p *protocol) PUB(client *client, params [][]byte)  error {
	var err error
	topicName := string(params[1])
	messageLen := make([]byte,4)
	_, err  = io.ReadFull(client.Reader, messageLen)
	if err != nil {
		return err
	}
	bodyLen:= int32(binary.BigEndian.Uint32(messageLen))
	messageBody := make([]byte, bodyLen)
	_, err = io.ReadFull(client.Reader, messageBody)
	if err != nil {
		return err
	}

	topic := p.nsqd.GetTopic(topicName)
	msg := NewMessage(topic.GenerateID(), messageBody)
	log.Printf("receive message from %s, topic:%s, message: %s",client.RemoteAddr(),topicName,string(messageBody))
	_ = topic.PutMessage(msg)
	return nil
}

在 PUB 方法中,首先我们可以从 params 中拿到 topicName ;然后我们进行第二次读取来拿 message,首先读取4个字节拿到 message 的大小 bodyLen,然后再读取 bodyLen 大小的内容得到 message,然后我们会调用 GetTopic 方法获取指定 topic(注意 GetTopic 方法在找不到指定topic 的时候会新建一个),最后会将该 message 存放到 topic 中。

func (p *protocol) SUB(client *client, params [][]byte)  error {
	topicName := string(params[1])
	channelName := string(params[2])

	var channel *Channel
	topic := p.nsqd.GetTopic(topicName)
	channel = topic.GetChannel(channelName)
	// update message pump
	client.SubEventChan <- channel

	return nil
}

SUB 方法同样会先获取指定的 topic 和指定的 channel,然后将该 channel 发送到 client 的 SubEventChan 中(这里其实就是 channel 和 client 绑定的地方),表示有订阅事件发生。

注意在上面的 IOLoop 方法中,我们一开始就会新建一个工作协程调用 messagePump 方法用来处理消费者相关。接下来我们仔细看看这个方法。

func (p *protocol) messagePump(client *client) {
	var err error
	var memoryMsgChan chan *Message
	var subChannel *Channel
	//这里新创建subEventChan是为了在下面可以把它置为nil以实现“一个客户端只能订阅一次”的目的
	subEventChan := client.SubEventChan

	for {
		select {
		case subChannel = <-subEventChan:  //表示有订阅事件发生,这里的subChannel就是消费者实际绑定的channel
			log.Printf("topic:%s channel:%s 发生订阅事件",subChannel.topicName,subChannel.name)
			memoryMsgChan = subChannel.memoryMsgChan
			// you can't SUB anymore
			subEventChan = nil
		case msg := <-memoryMsgChan: //如果channel对应的内存通道有消息的话
			err = p.SendMessage(client, msg)
			if err != nil {
                                go func() {
					_ = subChannel.PutMessage(msg)
				}()
				log.Printf("PROTOCOL(V2): [%s] messagePump error - %s", client.RemoteAddr(), err)
				goto exit
			}
		}
	}

exit:
	log.Printf("PROTOCOL(V2): [%s] exiting messagePump", client.RemoteAddr())
}

我们首先会监听来自 subEventChan 的订阅消息(就是上面的SUB方法传递过来的);客户端订阅之后,我们会监听客户端绑定的channel,如果有message 的话,就会发送给该客户端。此处如果发送失败的话,我们目前的解决办法就是把 message 重新放入到 channel 中,等待再次发送。

接下来我们再看看具体负责消息发送的相关方法

func (p *protocol) SendMessage(client *client, msg *Message) error {
	log.Printf("PROTOCOL(V2): writing to client(%s) - message: %s", client.RemoteAddr(), msg.Body)

	msgByte, err := msg.Bytes()
	if err != nil {
		return err
	}
	return p.Send(client, msgByte)
}

func (p *protocol) Send(client *client,data []byte) error {
	client.Lock()
	defer client.Unlock()
	_, err := SendFramedResponse(client.Writer, data)
	if err != nil {
		return err
	}
	//因为client.Writer使用了bufio缓存,所以这里我们就先暂时强制刷新
	err = client.Flush()
	return err
}

// 进行实际发送,并且会在消息前面加上4bytes的消息长度
func SendFramedResponse(w io.Writer, data []byte) (int, error) {
	beBuf := make([]byte, 4)
	size := uint32(len(data))

	binary.BigEndian.PutUint32(beBuf, size)
	n, err := w.Write(beBuf)
	if err != nil {
		return n, err
	}

	n, err = w.Write(data)
	return n + 4, err
}

SendMessage 负责将 Message 发送给消费者,它最终会调用 SendFramedResponse 在最终的消息前面加上4bytes的消息长度并进行发送。

测试

测试代码分析

apps/nsqd/main.go

func main() {
	log.SetFlags(log.Lshortfile | log.Ltime)
	_, err := nsqd.Start()
	if err != nil {
		log.Printf("failed to instantiate nsqd - %s", err)
	}
	select {
	}
}

nsq 的启动程序很简单,直接调用 nsqd 的 Start 即可,此处为了防止程序的退出,我们暂时就用 select 阻塞住。

apps/client/client.go

func main() {
	nsqdAddr := "127.0.0.1:4150"
	conn, err := net.Dial("tcp", nsqdAddr)
	go readFully(conn)
	if err != nil {
		log.Fatal(err)
	}
	cmd := Publish("mytopic", []byte("ha ha"))
	cmd.WriteTo(conn)

	cmd = Subscribe("mytopic", "mychannel")
	cmd.WriteTo(conn)

	select {

	}
}

func readFully(conn net.Conn) {
	len:=make([]byte, 4)
	for {
		_, err := conn.Read(len)
		if err != nil {
			fmt.Printf("error during read: %s", err)
		}
		size :=binary.BigEndian.Uint32(len)
		data := make([]byte, size)
		var n int
		n, err = conn.Read(data)
		if err != nil {
			fmt.Printf("error during read: %s", err)
		}
		fmt.Printf("receive: <%s> ,size:%d\n", data[16:n],n)
	}
}

对于测试的 client 来说,首先我们去连接本地的4150端口,然后发送一次 PUB 消息,发送一次 SUB 消息,同时我们会新启一个协程不断的接收服务端的消息。读取消息的时候,首先读取4个字节的长度,然后忽略 message ID,直接打印出消息体。发送消息的时候我们用到了 command.go 来编码消息,该文件并不复杂,大家可自行查看。

测试结果

我们先启动服务端,再启动客户端,最终会看到如下结果

服务端

Snipaste_2021-05-09_13-39-17.png

客户端

Snipaste_2021-05-09_13-39-38.png

代码地址

git clone https://github.com/xianxueniao150/mini-nsq.git
git checkout day01