Memsniff代码阅读

346 阅读3分钟

Memsniff

官方介绍

blog.box.com/introducing…

代码

github.com/box/memsnif…

协议解析部分-状态机模式

状态机是一个抽象机器，具有两个主要部分：状态和转换。状态是指系统的当前状态，一个状态机在任意时间点只会有一个激活状态。转换是指从当前状态切换到一个新状态，在转换发生之前或之后通常会执行一个或多个动作。我个人的理解是，状态机有一个前提：必须有源源不断的外部输入（交互）来触发状态的改变。比如在 Memsniff 中，这个源源不断的输入就是抓到的包中的字节流。

延伸阅读

Memcache协议 github.com/memcached/m… github.com/memcached/m…

代码

状态机的初始化
// NewFsm builds a memcache state machine that writes its diagnostics
// through logger. The machine begins in peekBinaryProtocolMagicByte, i.e.
// its first step inspects the first byte sent by the client.
func NewFsm(logger log.Logger) model.Fsm {
	f := &fsm{logger: logger}
	// Initial state: look at the first client byte.
	f.state = f.peekBinaryProtocolMagicByte
	return f
}
状态机的运行
// Run repeatedly invokes the current state until one of them returns a
// non-nil error, then reacts to that error:
//   - reader.ErrShortRead or io.EOF: Run simply returns, leaving f.state as
//     the failing state left it (presumably so a later Run can continue once
//     more data is buffered).
//   - anything else (the original comment calls these data loss or protocol
//     errors): both readers are reset and the machine is pointed at
//     readCommand, so the next Run tries to resync at the next command.
func (f *fsm) Run() {
	var err error
	for err == nil {
		err = f.state()
	}

	if err == reader.ErrShortRead || err == io.EOF {
		return
	}

	// Data lost or protocol error: drop buffered bytes and resync.
	f.log(2, "trying to resync after error:", err)
	f.consumer.ClientReader.Reset()
	f.consumer.ServerReader.Reset()
	f.state = f.readCommand
}

解析第一个字节，主要是为了区分 Memcache 的协议类型：如果第一个字节是 0x80，就认为是二进制协议，直接忽略该连接（Memsniff 目前还不支持二进制协议）；否则按文本协议进入 readCommand 状态。

// peekBinaryProtocolMagicByte is the initial state. It truncates the server
// reader, then peeks (without consuming) at the first client byte:
//   - 0x80 is treated as the binary memcached protocol, which this fsm does
//     not handle: the consumer is closed and io.EOF is returned.
//   - any other byte switches the machine to readCommand.
//
// If the peek fails with reader.ErrLostData, the client reader is truncated
// so the next attempt starts from a fresh client packet, and
// reader.ErrShortRead is returned instead. Other errors are returned as-is.
func (f *fsm) peekBinaryProtocolMagicByte() error {
	f.consumer.ServerReader.Truncate()

	head, err := f.consumer.ClientReader.PeekN(1)
	if err != nil {
		if _, lost := err.(reader.ErrLostData); !lost {
			return err
		}
		// Retry later, making sure we read from the start of a client packet.
		f.consumer.ClientReader.Truncate()
		return reader.ErrShortRead
	}

	if head[0] != 0x80 {
		// Text protocol: continue with command parsing.
		f.state = f.readCommand
		return nil
	}

	// Binary memcached protocol: give up on this connection.
	f.log(2, "looks like binary protocol, ignoring connection")
	f.consumer.Close()
	return io.EOF
}

readCommand 主要是读取并判断当前请求是什么命令。对于已知命令，状态机先转换为 readArgs 读取参数；参数读完后，再通过 commandState 转换为相应的处理状态。比如请求命令是 "get"、"gets"，最终状态就是 f.handleGet；如果请求命令是 "set"、"add"、"replace"、"append"、"prepend"、"cas"，最终状态就是 f.handleSet；未知命令则进入 f.handleUnknown。

// readCommand reads the first word of a client request into f.cmd and
// chooses the next state.
//
// The word ends at the first ' ' or '\n' on the client stream, and that
// delimiter is consumed together with the word. The next state is:
//   - handleUnknown if commandState has no handler for f.cmd;
//   - the command's handler directly if the delimiter was '\n'. The request
//     line has ended, so there are no arguments to read, and going through
//     readArgs would wrongly treat the next request's first word as an
//     argument (e.g. for "quit\r\n");
//   - readArgs otherwise.
//
// Returns errProtocolDesync when f.cmd does not match asciiRe. Reader errors
// are returned unchanged.
func (f *fsm) readCommand() error {
	f.args = f.args[:0]
	f.consumer.ServerReader.Truncate()
	f.log(3, "reading command")
	pos, err := f.consumer.ClientReader.IndexAny(" \n")
	if err != nil {
		return err
	}

	cmd, err := f.consumer.ClientReader.ReadN(pos + 1)
	if err != nil {
		return err
	}
	// The last byte read is the delimiter IndexAny found.
	lineEnded := cmd[len(cmd)-1] == '\n'
	f.cmd = string(bytes.TrimRight(cmd, " \r\n"))
	f.log(3, "read command:", f.cmd)

	if !asciiRe.MatchString(f.cmd) {
		return errProtocolDesync
	}

	next := f.commandState()
	switch {
	case next == nil:
		f.state = f.handleUnknown
	case lineEnded:
		// No arguments follow; the request line is already fully consumed.
		f.state = next
	default:
		f.state = f.readArgs
	}
	return nil
}

// commandState returns the handler state for the command currently stored
// in f.cmd, or nil when the command has no dedicated handler:
//   - "get", "gets" -> handleGet
//   - "set", "add", "replace", "append", "prepend", "cas" -> handleSet
//   - "quit" -> handleQuit
func (f *fsm) commandState() state {
	var next state
	switch f.cmd {
	case "get", "gets":
		next = f.handleGet
	case "set", "add", "replace", "append", "prepend", "cas":
		next = f.handleSet
	case "quit":
		next = f.handleQuit
	}
	return next
}

// readArgs consumes one argument word from the client stream (up to and
// including the next ' ' or '\n') and appends it to f.args, with the
// delimiter and any trailing '\r' removed.
//
// A ' ' delimiter keeps the machine in readArgs for the next word. Any other
// delimiter ends the argument list, and the machine moves to the state
// returned by commandState.
func (f *fsm) readArgs() error {
	f.consumer.ServerReader.Truncate()

	idx, err := f.consumer.ClientReader.IndexAny(" \n")
	if err != nil {
		return err
	}
	token, err := f.consumer.ClientReader.ReadN(idx + 1)
	if err != nil {
		return err
	}

	last := len(token) - 1
	arg, delim := token[:last], token[last]
	f.args = append(f.args, string(bytes.TrimRight(arg, "\r")))

	if delim != ' ' {
		// End of the request line: hand off to the command's handler.
		f.log(3, "read arguments:", f.args)
		f.state = f.commandState()
	}
	return nil
}

以 handleSet 为例：handleSet 从 f.args[3] 解析出数据块大小 size，先在客户端字节流中跳过 size 字节的数据块及结尾的 \r\n，然后把状态切换到 discardResponse，丢弃服务端的响应。

// handleSet handles the storage commands (set, add, replace, append,
// prepend, cas) once their arguments are in f.args.
//
// f.args[3] is the byte count of the data block that follows the request
// line. That many bytes plus the trailing CRLF are skipped on the client
// stream. Then:
//   - if the last argument is "noreply", the server sends no reply, so the
//     machine returns straight to readCommand. Waiting for a reply here would
//     swallow the response to the client's next request.
//   - otherwise the server's one-line reply is discarded via discardResponse.
//
// Malformed argument lists (fewer than four arguments, or a byte count that
// is not a non-negative integer) skip nothing on the client side and go
// straight to discardResponse.
func (f *fsm) handleSet() error {
	if len(f.args) < 4 {
		return f.discardResponse()
	}
	size, err := strconv.Atoi(f.args[3])
	if err != nil || size < 0 {
		// A negative length is as malformed as a non-numeric one; never hand
		// a negative count to Discard.
		return f.discardResponse()
	}
	f.log(3, "discarding", size+len(crlf), "from client")
	if _, err := f.consumer.ClientReader.Discard(size + len(crlf)); err != nil {
		return err
	}
	// "noreply" is an optional final argument (index 4, or 5 for cas).
	if len(f.args) > 4 && f.args[len(f.args)-1] == "noreply" {
		f.log(3, "noreply requested, not expecting a response from server")
		f.state = f.readCommand
		return nil
	}
	f.log(3, "discarding response from server")
	return f.discardResponse()
}

discardResponse 负责从服务端字节流中读取并丢弃一行响应，然后把状态切换回 readCommand，准备解析下一条命令。

// discardResponse reads one line from the server stream and throws it away,
// then moves the machine back to readCommand.
//
// The machine is parked in discardResponse before the read. If ReadLine
// fails, the error is returned and the next Run resumes here rather than in
// whichever state called this function.
func (f *fsm) discardResponse() error {
	f.state = f.discardResponse
	f.log(3, "discarding response from server")

	resp, err := f.consumer.ServerReader.ReadLine()
	if err == nil {
		f.log(3, "discarded response from server:", string(resp))
		f.state = f.readCommand
	}
	return err
}

聚合函数-工厂模式

代码

定义统一的聚合函数接口
// Aggregator folds a stream of int64 data points into a single int64 result.
type Aggregator interface {
	// Reset returns the aggregator to its initial state.
	Reset()
	// Add records a single data point.
	Add(n int64)
	// Result returns the final output of aggregation.
	Result() int64
}

以 Max 和 Min 聚合函数为例

// Max is an Aggregator that keeps the largest value passed to Add.
// Its Result is 0 until the first Add, and again after Reset.
type Max struct {
	max       int64
	seenFirst bool
}

// Add folds n into the running maximum. The first value after creation or
// Reset is always taken, even if it is negative.
func (m *Max) Add(n int64) {
	if m.seenFirst && n <= m.max {
		return
	}
	m.max = n
	m.seenFirst = true
}

// Result reports the largest value seen so far.
func (m *Max) Result() int64 {
	return m.max
}

// Reset forgets all values seen so far.
func (m *Max) Reset() {
	*m = Max{}
}

// Min is an Aggregator that keeps the smallest value passed to Add.
// Its Result is 0 until the first Add, and again after Reset.
type Min struct {
	min       int64
	seenFirst bool
}

// Add folds n into the running minimum. The first value after creation or
// Reset is always taken, even if it is positive.
func (m *Min) Add(n int64) {
	if m.seenFirst && n >= m.min {
		return
	}
	m.min = n
	m.seenFirst = true
}

// Result reports the smallest value seen so far.
func (m *Min) Result() int64 {
	return m.min
}

// Reset forgets all values seen so far.
func (m *Min) Reset() {
	*m = Min{}
}
工厂的创建
// NewKeyAggregatorFactory creates a KeyAggregatorFactory from a
// comma-separated descriptor of field names (key, size, etc.) and aggregate
// descriptions (sum(size), p99(latency), etc.). Surrounding whitespace on each
// entry is ignored.
//
// Entries without an aggregate become key fields. Entries with an aggregate
// must name an integer field (model.IntFields), otherwise
// BadDescriptorError(field) is returned. Errors from parseField and
// NewFactoryFromDescriptor are returned unchanged, together with a zero
// KeyAggregatorFactory.
func NewKeyAggregatorFactory(desc string) (KeyAggregatorFactory, error) {
	var kaf KeyAggregatorFactory
	for _, raw := range strings.Split(desc, ",") {
		field := strings.TrimSpace(raw)

		fieldID, aggDesc, err := parseField(field)
		if err != nil {
			return KeyAggregatorFactory{}, err
		}

		if aggDesc == "" {
			// Plain field: it becomes part of the grouping key.
			kaf.KeyFields = append(kaf.KeyFields, field)
			kaf.keyFieldMask |= fieldID
			continue
		}

		// Only integer fields can be aggregated.
		if fieldID&model.IntFields == 0 {
			return KeyAggregatorFactory{}, BadDescriptorError(field)
		}
		aggFactory, aggErr := NewFactoryFromDescriptor(aggDesc)
		if aggErr != nil {
			return KeyAggregatorFactory{}, aggErr
		}
		kaf.AggFields = append(kaf.AggFields, field)
		kaf.aggFieldIDs = append(kaf.aggFieldIDs, fieldID)
		kaf.aggFactories = append(kaf.aggFactories, aggFactory)
	}
	return kaf, nil
}
// NewFactoryFromDescriptor maps an aggregate name to an AggregatorFactory
// that builds a fresh Aggregator each time it is called:
//
//	"max" -> *Max, "min" -> *Min, "avg" -> *Mean, "sum" -> *Sum
//
// Descriptors that are at least three characters long and start with 'p'
// (e.g. "p99") are delegated to percentileFactoryFromDescriptor. Anything
// else yields BadDescriptorError(desc).
func NewFactoryFromDescriptor(desc string) (AggregatorFactory, error) {
	var factory AggregatorFactory
	switch desc {
	case "max":
		factory = func() Aggregator { return &Max{} }
	case "min":
		factory = func() Aggregator { return &Min{} }
	case "avg":
		factory = func() Aggregator { return &Mean{} }
	case "sum":
		factory = func() Aggregator { return &Sum{} }
	default:
		looksLikePercentile := len(desc) >= 3 && desc[0] == 'p'
		if !looksLikePercentile {
			return nil, BadDescriptorError(desc)
		}
		return percentileFactoryFromDescriptor(desc)
	}
	return factory, nil
}
聚合函数的使用
// KeyAggregator holds the aggregators for every requested event field of a
// single key.
type KeyAggregator struct {
	// Key is the list of key fields over which we are aggregating.
	Key []string

	// aggFieldIDs is the list of event fields whose values we take for aggregation,
	// in the same order as aggs and as the descriptor string provided to the
	// KeyAggregatorFactory.
	aggFieldIDs []model.EventFieldMask
	// aggs is the actual aggregators, in the same order as the descriptor string.
	aggs []Aggregator
}

// Add feeds the event into every aggregator. Aggregator i receives the value
// of event field aggFieldIDs[i].
func (ka KeyAggregator) Add(e model.Event) {
	for i, agg := range ka.aggs {
		agg.Add(fieldAsInt64(e, ka.aggFieldIDs[i]))
	}
}

// Result returns one value per aggregator, in descriptor order. The slice is
// freshly allocated and never nil.
func (ka KeyAggregator) Result() []int64 {
	out := make([]int64, 0, len(ka.aggs))
	for _, agg := range ka.aggs {
		out = append(out, agg.Result())
	}
	return out
}

// Reset resets every aggregator and clears Key.
func (ka *KeyAggregator) Reset() {
	for i := range ka.aggs {
		ka.aggs[i].Reset()
	}
	ka.Key = nil
}