gopacket tcpassembly源码分析,成功拿到offer

90 阅读4分钟

调用

参考示例example/httpassembly

  1. 自定义一个factory,实现New接口
// httpStream ties one tcpreader.ReaderStream to the flows that identify the
// connection it carries; its run method consumes r.
type httpStream struct {
	// net is the network-layer flow of the connection.
	net gopacket.Flow
	// transport is the transport-layer flow of the connection.
	transport gopacket.Flow
	// r receives the reassembled bytes; New hands &r to the assembler.
	r tcpreader.ReaderStream
}
// New is shown here only as a signature sketch: the assembler calls it once
// per new connection, passing the connection's network and transport flows.
// The complete implementation follows further below.
// NOTE(review): as written this stub has no return statement and will not
// compile on its own.
func (h *httpStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
}

  2. New方法创建并保存一个tcpreader.NewReaderStream()流,启动读取该流的协程,然后返回这个流

// New builds an httpStream for the connection identified by net/transport,
// starts the goroutine that parses it, and returns the stream's embedded
// ReaderStream as the tcpassembly.Stream the assembler will feed.
func (h *httpStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
	s := new(httpStream)
	s.net = net
	s.transport = transport
	s.r = tcpreader.NewReaderStream()

	// The reader stream must be drained continuously, so the consumer is
	// started before the stream is handed to the assembler.
	go s.run()

	// Return the address of the field so the assembler and run share the
	// same ReaderStream instance.
	return &s.r
}

  3. 流处理协程:用bufio.Reader包装ReaderStream,循环从中解析HTTP请求,直到读到io.EOF
// run parses consecutive HTTP requests out of the reassembled byte stream
// until the stream reports io.EOF, which is the only way the loop exits.
// The stream must be read all the way to EOF (see the note in New).
// Parse errors are logged and reading continues.
func (h *httpStream) run() {
	reader := bufio.NewReader(&h.r)
	for {
		req, err := http.ReadRequest(reader)
		switch {
		case err == io.EOF:
			// End of the stream: we have consumed everything.
			return
		case err != nil:
			// Malformed or truncated request: report it and keep going.
			log.Println("Error reading stream", h.net, h.transport, ":", err)
		default:
			n := tcpreader.DiscardBytesToEOF(req.Body)
			req.Body.Close()
			log.Println("Received request from stream", h.net, h.transport, ":", req, "with", n, "bytes in request body")
		}
	}
}

  4. 使用方式:与reassembly包的用法相同
// main captures live traffic from *iface (restricted by the BPF *filter),
// reassembles TCP streams with tcpassembly, and hands every stream to an
// httpStreamFactory for HTTP request parsing.
func main() {
	defer util.Run()()

	// 1. Open the capture device.
	handle, err := pcap.OpenLive(*iface, int32(*snaplen), true, pcap.BlockForever)
	if err != nil {
		log.Fatal(err)
	}
	defer handle.Close()
	// Only packets matching the BPF expression reach the assembler.
	if err := handle.SetBPFFilter(*filter); err != nil {
		log.Fatal(err)
	}

	// 2. Set up the stream factory, the connection pool and the assembler.
	streamFactory := &httpStreamFactory{}
	streamPool := tcpassembly.NewStreamPool(streamFactory)
	assembler := tcpassembly.NewAssembler(streamPool)

	log.Println("reading in packets")
	// 3. Decode packets from the capture handle.
	packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
	packets := packetSource.Packets()
	// NewTicker + Stop (instead of time.Tick) so the ticker is released on return.
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		// 4. Read the next packet.
		case packet := <-packets:
			// A nil packet means the packet source has no more packets
			// (end of a pcap file, or presumably a failed capture handle).
			if packet == nil {
				// Push out any data still buffered behind a gap and close every
				// stream, so each stream's goroutine receives io.EOF. Returning
				// without this silently dropped that data.
				assembler.FlushAll()
				return
			}
			if *logAllPackets {
				log.Println(packet)
			}
			netLayer, transportLayer := packet.NetworkLayer(), packet.TransportLayer()
			if netLayer == nil || transportLayer == nil || transportLayer.LayerType() != layers.LayerTypeTCP {
				log.Println("Unusable packet")
				continue
			}
			tcp := transportLayer.(*layers.TCP)
			// 5. Hand the TCP segment straight to the assembler.
			assembler.AssembleWithTimestamp(netLayer.NetworkFlow(), tcp, packet.Metadata().Timestamp)

		case <-ticker.C:
			// 6. Once a second, tell the assembler to stop waiting on
			// connections with nothing newer than two minutes ago.
			assembler.FlushOlderThan(time.Now().Add(time.Minute * -2))
		}
	}
}

Assembler

// AssemblerOptions bounds how many pages an Assembler buffers while waiting
// for missing (out-of-order) TCP data. A limit <= 0 disables that check.
type AssemblerOptions struct {
	// MaxBufferedPagesTotal caps the pages in use across all connections.
	// Once reached, every further insert of out-of-order data force-flushes
	// the first buffered page of the connection being inserted into.
	// Ignored when <= 0.
	MaxBufferedPagesTotal int
	// MaxBufferedPagesPerConnection caps the pages buffered for one
	// connection. Once reached, that connection's first (lowest-sequence)
	// buffered page is force-flushed. Ignored when <= 0.
	MaxBufferedPagesPerConnection int
}

// Assembler turns individual TCP packets into ordered byte chunks delivered
// to each connection's Stream.
type Assembler struct {
	// Buffering limits; NewAssembler initializes them from DefaultAssemblerOptions.
	AssemblerOptions
	// ret is the reusable batch of Reassembly chunks handed to a Stream;
	// it is truncated at the start of every AssembleWithTimestamp call.
	ret []Reassembly
	// pc supplies and recycles the pages that hold buffered payload.
	pc *pageCache
	// connPool is the connection pool; it may be shared with other Assemblers.
	connPool *StreamPool
}

// NewAssembler creates an Assembler that takes its connections from pool.
// The pool may be shared between Assemblers: each call registers one more
// user by incrementing pool.users under pool.mu. Options start out as
// DefaultAssemblerOptions (presumably "no limits" — defined elsewhere) and
// can be changed on the returned value via the embedded AssemblerOptions.
//
// NOTE(review): a functional-options constructor would allow non-default
// options to be set at construction time.
func NewAssembler(pool *StreamPool) *Assembler {
	pool.mu.Lock()
	pool.users++
	pool.mu.Unlock()

	a := &Assembler{AssemblerOptions: DefaultAssemblerOptions}
	a.ret = make([]Reassembly, assemblerReturnValueInitialSize)
	a.pc = newPageCache()
	a.connPool = pool
	return a
}

AssembleWithTimestamp

func (a *Assembler) AssembleWithTimestamp(netFlow gopacket.Flow, t *layers.TCP, timestamp time.Time) {
	// 忽略空的数据包,比如keepalived
	// tcp握手时 t.SYN = 1 t.FIN = 0 t.RST = 0 len(t.LayerPayload()) == 0 
	// 即false && true && true && true
	// tcp挥手时 t.SYN = 0 t.FIN = 1 t.RST = 0 len(t.LayerPayload()) == 0 
	// 即true && false && true && true
	if !t.SYN && !t.FIN && !t.RST && len(t.LayerPayload()) == 0 {
		if *debugLog {
			log.Println("ignoring useless packet")
		}
		return
	}

	a.ret = a.ret[:0]
	// 4元组组成的key
	key := key{netFlow, t.TransportFlow()}

	var conn *connection
	// This for loop handles a race condition where a connection will close, lock
	// the connection pool, and remove itself, but before it locked the connection
	// pool it's returned to another Assemble statement.  This should loop 0-1
	// times for the VAST majority of cases.
	// 创建conn
	for {
		// tcp keepalive syn=0 payload=0
		// 即 true && true end 为true?
		conn = a.connPool.getConnection(key, !t.SYN && len(t.LayerPayload()) == 0, timestamp)
		if conn == nil {
			if *debugLog {
				log.Printf("%v got empty packet on otherwise empty connection", key)
			}
			return
		}
		conn.mu.Lock()
		if !conn.closed {
			break
		}
		conn.mu.Unlock()
	}
	if conn.lastSeen.Before(timestamp) {
		conn.lastSeen = timestamp
	}
	//type Sequence int64 提供Difference和Add函数
	seq, bytes := Sequence(t.Seq), t.Payload // seq:当前序号  bytes:tcp负载的数据
	// 校验序号
	if conn.nextSeq == invalidSequence {
		if t.SYN {
			if *debugLog {
				log.Printf("%v saw first SYN packet, returning immediately, seq=%v", key, seq)
			}
			// 添加 Reassembly重组后的对象
			a.ret = append(a.ret, Reassembly{
				Bytes: bytes,
				Skip:  0,
				Start: true,
				Seen:  timestamp,
			})
			// 下一个包的序号 = 当前的序号 + 字节数 + 1
			conn.nextSeq = seq.Add(len(bytes) + 1)
		} else {
			if *debugLog {
				log.Printf("%v waiting for start, storing into connection", key)
			}
			// 插入到数据到connection中
			a.insertIntoConn(t, conn, timestamp)
		}
	} else if diff := conn.nextSeq.Difference(seq); diff > 0 {
		if *debugLog {
			log.Printf("%v gap in sequence numbers (%v, %v) diff %v, storing into connection", key, conn.nextSeq, seq, diff)
		}
		// 插入到数据到connection中
		a.insertIntoConn(t, conn, timestamp)
	} else {=<0
		// 字节校准
		bytes, conn.nextSeq = byteSpan(conn.nextSeq, seq, bytes)
		if *debugLog {
			log.Printf("%v found contiguous data (%v, %v), returning immediately", key, seq, conn.nextSeq)
		}
		a.ret = append(a.ret, Reassembly{
			Bytes: bytes,
			Skip:  0,
			End:   t.RST || t.FIN,
			Seen:  timestamp,
		})
	}
	if len(a.ret) > 0 {
		a.sendToConnection(conn)
	}
	conn.mu.Unlock()
}

insertIntoConn

// insertIntoConn buffers t's payload in conn's page list at the position
// matching its sequence number. If a buffering limit from AssemblerOptions is
// reached, it force-pops the connection's first buffered page into a.ret.
func (a *Assembler) insertIntoConn(t *layers.TCP, conn *connection, ts time.Time) {
	// Invariant check: a first page sitting exactly at nextSeq should already
	// have been delivered as contiguous data, so this indicates corrupt state.
	if conn.first != nil && conn.first.seq == conn.nextSeq {
		panic("wtf")
	}
	head, tail, count := a.pagesFromTCP(t, ts)

	// Packets can arrive out of order, so the new pages are spliced into the
	// list at the slot for this sequence number (located by traverseConn,
	// defined elsewhere) instead of being appended at the end.
	before, after := conn.traverseConn(Sequence(t.Seq))
	conn.pushBetween(before, after, head, tail)
	conn.pages += count

	connFull := a.MaxBufferedPagesPerConnection > 0 && conn.pages >= a.MaxBufferedPagesPerConnection
	totalFull := a.MaxBufferedPagesTotal > 0 && a.pc.used >= a.MaxBufferedPagesTotal
	if !connFull && !totalFull {
		return
	}
	if *debugLog {
		log.Printf("%v hit max buffer size: %+v, %v, %v", conn.key, a.AssemblerOptions, conn.pages, a.pc.used)
	}
	// Over a limit: give up waiting and release the first buffered page.
	a.addNextFromConn(conn)
}

pagesFromTCP

从TCP数据包创建一个page(若负载超过单页容量,则创建一组pages)。

注意此函数不应该接受SYN包,因为它不能正确处理seq。

返回双向链表形式的page列表中的第一页、最后一页,以及页数。

// pagesFromTCP copies t's payload into pages taken from the page cache. Each
// page holds at most pageBytes bytes and the pages are chained into a
// doubly-linked list. It returns the first page, the last page and the page
// count. Each page's seq is the sequence number of its first byte. Only the
// last page gets End set (when t has FIN or RST). An empty payload still
// yields one empty page.
//
// Not meant for SYN packets: it does not account for the sequence number
// consumed by the SYN flag.
func (a *Assembler) pagesFromTCP(t *layers.TCP, ts time.Time) (p, p2 *page, numPages int) {
	head := a.pc.next(ts)
	tail := head
	numPages = 1
	seq, remaining := Sequence(t.Seq), t.Payload
	for {
		n := min(len(remaining), pageBytes)
		tail.Bytes = tail.buf[:n]
		copy(tail.Bytes, remaining)
		tail.seq = seq

		remaining = remaining[n:]
		if len(remaining) == 0 {
			break
		}
		// Payload larger than one page: continue in a freshly linked page.
		seq = seq.Add(n)
		fresh := a.pc.next(ts)
		tail.next = fresh
		fresh.prev = tail
		tail = fresh
		numPages++
	}
	tail.End = t.RST || t.FIN
	return head, tail, numPages
}

addNextFromConn

弹出连接中缓存的第一页,将其追加到a.ret中,并把该页归还给pageCache

// addNextFromConn pops conn's first buffered page, appends its Reassembly to
// a.ret, and returns the page to the page cache.
//
// Skip is set to -1 when no stream start was ever seen (nextSeq invalid).
// It is set to the size of the gap when the page starts after nextSeq.
// conn.nextSeq is updated via byteSpan (defined elsewhere).
//
// The appended Reassembly's Bytes still points into the page's buffer even
// though the page is back in the cache. a.ret must therefore be consumed
// before the cache hands that page out again.
func (a *Assembler) addNextFromConn(conn *connection) {
	head := conn.first
	if conn.nextSeq == invalidSequence {
		head.Skip = -1
	} else if gap := conn.nextSeq.Difference(head.seq); gap > 0 {
		head.Skip = int(gap)
	}
	head.Bytes, conn.nextSeq = byteSpan(conn.nextSeq, head.seq, head.Bytes)
	if *debugLog {
		log.Printf("%v   adding from conn (%v, %v)", conn.key, head.seq, conn.nextSeq)
	}
	a.ret = append(a.ret, head.Reassembly)
	a.pc.replace(head)

	// Unlink the popped page from the front of the list.
	if head == conn.last {
		conn.first, conn.last = nil, nil
	} else {
		conn.first = head.next
		conn.first.prev = nil
	}
	conn.pages--
}

sendToConnection

// sendToConnection appends any buffered pages that have become contiguous,
// delivers the whole a.ret batch to conn's Stream, and closes the connection
// when the last delivered chunk has End set.
//
// a.ret must be non-empty after addContiguous, or the final index panics.
// AssembleWithTimestamp only calls this when len(a.ret) > 0.
func (a *Assembler) sendToConnection(conn *connection) {
	a.addContiguous(conn)
	stream := conn.stream
	if stream == nil {
		// Every live connection is expected to carry a Stream.
		panic("why?")
	}
	stream.Reassembled(a.ret)
	if last := a.ret[len(a.ret)-1]; last.End {
		a.closeConnection(conn)
	}
}

addContiguous

// addContiguous pops buffered pages from the front of conn into a.ret for as
// long as the first page starts at or before conn.nextSeq.
func (a *Assembler) addContiguous(conn *connection) {
	for {
		head := conn.first
		if head == nil || conn.nextSeq.Difference(head.seq) > 0 {
			// Nothing buffered, or the next buffered page is past a gap.
			return
		}
		a.addNextFromConn(conn)
	}
}

addNextFromConn

弹出第一页添加到数组中

// addNextFromConn pops conn's first buffered page, appends its Reassembly to
// a.ret, and returns the page to the page cache.
//
// Skip is set to -1 when no stream start was ever seen (nextSeq invalid).
// It is set to the size of the gap when the page starts after nextSeq.
// conn.nextSeq is updated via byteSpan (defined elsewhere).
//
// The appended Reassembly's Bytes still points into the page's buffer even
// though the page is back in the cache. a.ret must therefore be consumed
// before the cache hands that page out again.
func (a *Assembler) addNextFromConn(conn *connection) {
	head := conn.first
	if conn.nextSeq == invalidSequence {
		head.Skip = -1
	} else if gap := conn.nextSeq.Difference(head.seq); gap > 0 {
		head.Skip = int(gap)
	}
	head.Bytes, conn.nextSeq = byteSpan(conn.nextSeq, head.seq, head.Bytes)
	if *debugLog {
		log.Printf("%v   adding from conn (%v, %v)", conn.key, head.seq, conn.nextSeq)
	}
	a.ret = append(a.ret, head.Reassembly)
	a.pc.replace(head)

	// Unlink the popped page from the front of the list.
	if head == conn.last {
		conn.first, conn.last = nil, nil
	} else {
		conn.first = head.next
		conn.first.prev = nil
	}
	conn.pages--
}

closeConnection

// closeConnection finishes conn. It tells the Stream that reassembly is
// complete, marks the connection closed, returns its buffered pages to the
// page cache, and finally hands the connection back to the StreamPool.
//
// AssembleWithTimestamp reaches this (via sendToConnection) while holding conn.mu.
func (a *Assembler) closeConnection(conn *connection) {
	if *debugLog {
		log.Printf("%v closing", conn.key)
	}
	conn.stream.ReassemblyComplete()
	conn.closed = true
	// Release the buffered pages BEFORE handing conn back to the pool. Once
	// remove() appends conn to the free list, another Assembler sharing the
	// pool may pop and reset it, so conn's fields must not be read after that.
	for p := conn.first; p != nil; p = p.next {
		a.pc.replace(p)
	}
	conn.first, conn.last = nil, nil
	conn.pages = 0
	a.connPool.remove(conn)
}

StreamPool

管理流的连接池,初始连接池分配1024个

// StreamPool tracks the live connections and recycles connection objects.
// It can be shared by several Assemblers (see users).
type StreamPool struct {
	// conns maps each connection key to its live connection.
	conns map[key]*connection
	// users counts the Assemblers using this pool; NewAssembler increments it.
	users int
	// mu guards conns, free and users.
	mu sync.RWMutex
	// factory creates the Stream attached to each new connection.
	factory StreamFactory
	// free holds connection objects ready for reuse.
	free []*connection
	// all retains every slab of connections allocated by grow.
	all [][]connection
	// nextAlloc is the size of the next slab grow allocates; it doubles each time.
	nextAlloc int
	// newConnectionCount counts newConnection calls; maintained only when memLog is set.
	newConnectionCount int64
}

// NewStreamPool returns an empty pool whose connections obtain their Stream
// from factory. The connection map and free list are pre-sized to
// initialAllocSize, and the first grow allocates that many connections.
func NewStreamPool(factory StreamFactory) *StreamPool {
	pool := &StreamPool{
		factory:   factory,
		nextAlloc: initialAllocSize,
	}
	pool.conns = make(map[key]*connection, initialAllocSize)
	pool.free = make([]*connection, 0, initialAllocSize)
	return pool
}

grow

批量分配连接(每次分配的数量翻倍)

// grow allocates a slab of p.nextAlloc connections and pushes a pointer to
// each onto the free list. It then doubles nextAlloc for the next call.
// It is reached from newConnection, which getConnection calls with p.mu held.
func (p *StreamPool) grow() {
	size := p.nextAlloc
	slab := make([]connection, size)
	p.all = append(p.all, slab)
	for i := 0; i < size; i++ {
		p.free = append(p.free, &slab[i])
	}
	if *memLog {
		log.Println("StreamPool: created", size, "new connections")
	}
	p.nextAlloc = size * 2
}

newConnection

创建连接

// newConnection takes a connection object from the free list, growing the
// pool first if the list is empty. It resets the object for key k, stream s
// and creation time ts, and returns it. getConnection calls it with p.mu held.
func (p *StreamPool) newConnection(k key, s Stream, ts time.Time) (c *connection) {
	if *memLog {
		p.newConnectionCount++
		// Report pool statistics once every 0x8000 requests.
		if p.newConnectionCount&0x7FFF == 0 {
			log.Println("StreamPool:", p.newConnectionCount, "requests,", len(p.conns), "used,", len(p.free), "free")
		}
	}
	n := len(p.free)
	if n == 0 {
		p.grow()
		n = len(p.free)
	}
	c = p.free[n-1]
	p.free = p.free[:n-1]
	c.reset(k, s, ts)
	return c
}

getConnection

// getConnection returns the live connection for key k.
//
// If none exists and end is true, it returns nil, so a packet that cannot
// start a stream never creates a connection. Otherwise it creates a Stream
// through the factory and registers a new connection for k, created at ts.
//
// factory.New runs outside the lock, so two callers can race to create the
// same key. The loser returns the winner's connection. It releases nothing
// back to the pool, because the connection is only taken from the free list
// once the key is known to be absent. It completes its own orphaned stream so
// any consumer of that stream (e.g. a reader goroutine) sees the end of
// input instead of blocking forever.
func (p *StreamPool) getConnection(k key, end bool, ts time.Time) *connection {
	p.mu.RLock()
	conn := p.conns[k]
	p.mu.RUnlock()
	if end || conn != nil {
		return conn
	}
	// factory.New is user code and may be slow; keep it outside the lock.
	s := p.factory.New(k[0], k[1])
	p.mu.Lock()
	if existing := p.conns[k]; existing != nil {
		p.mu.Unlock()
		// Another goroutine registered k while s was being created; s will
		// never receive data, so finish it.
		s.ReassemblyComplete()
		return existing
	}
	conn = p.newConnection(k, s, ts)
	p.conns[k] = conn
	p.mu.Unlock()
	return conn
}

remove

删除某个连接,并将其放回空闲列表

// remove unregisters conn from the live-connection map and pushes it onto
// the free list for reuse, under p.mu.
func (p *StreamPool) remove(conn *connection) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.conns, conn.key)
	p.free = append(p.free, conn)
}


connections

返回所有的连接

// connections returns a snapshot slice of all live connections, taken under
// the pool's read lock. Order follows Go map iteration and is unspecified.
func (p *StreamPool) connections() []*connection {
	p.mu.RLock()
	defer p.mu.RUnlock()
	snapshot := make([]*connection, 0, len(p.conns))
	for _, c := range p.conns {
		snapshot = append(snapshot, c)
	}
	return snapshot
}

connection

type connection struct {
	key               key
	pages             int
	first, last       *page
	nextSeq           Sequence
	created, lastSeen time.Time


![img](https://p9-xtjj-sign.byteimg.com/tos-cn-i-73owjymdk6/38e6dce61a40497b83c537de9a272b68~tplv-73owjymdk6-jj-mark-v1:0:0:0:0:5o6Y6YeR5oqA5pyv56S-5Yy6IEAg5py65Zmo5a2m5Lmg5LmL5b-DQUk=:q75.awebp?rk3s=f64ab15b&x-expires=1771252490&x-signature=%2BKO2sZelbUJHtCgpPZHER0HDPqw%3D)
![img](https://p9-xtjj-sign.byteimg.com/tos-cn-i-73owjymdk6/e3188ea961344809956a9c55f52277b5~tplv-73owjymdk6-jj-mark-v1:0:0:0:0:5o6Y6YeR5oqA5pyv56S-5Yy6IEAg5py65Zmo5a2m5Lmg5LmL5b-DQUk=:q75.awebp?rk3s=f64ab15b&x-expires=1771252490&x-signature=HKFNW1fq8TwANRWzoW7z6G4uTxE%3D)

**网上学习资料一大堆,但如果学到的知识不成体系,遇到问题时只是浅尝辄止,不再深入研究,那么很难做到真正的技术提升。**

**[需要这份系统化的资料的朋友,可以添加戳这里获取](https://gitee.com/vip204888)**


**一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!**