使用golang开发MySQL-binlog同步工具demo

259 阅读4分钟

背景

这篇是一个使用golang开发的binlog解析工具,更偏向demo和研究性质。简单来说,就是模拟MySQL binlog协议,开发一个服务,作为MySQL的“从库”,获取binlog,有点像java开发的canal。

实践

过程和结构

执行过程主要在server模块中完成。首先连接MySQL,握手部分参考了我们使用的中间件(kingshard)。接着关闭binlog的checksum校验,再以从库的身份注册到主库,并发送binlog_dump命令。最后循环读取binlog数据包,通过go-mysql提供的方法,将binlog events解析出来并打印。

代码

  1. config 配置部分,描述binlog文件,位置,主库MySQL账号信息等。
package app

// Config carries everything the Server needs to reach the MySQL master and to
// tell it where the binlog stream should start.
//
// Callers may build this struct with an unkeyed composite literal, so the
// field order should be kept stable.
type Config struct {
	// Host and Port form the TCP address that handshake dials.
	Host string
	// Port is also reported to the master in the register-slave command.
	Port int
	// User and Pass are the MySQL credentials; they are used for the login
	// handshake and are sent again in the register-slave command.
	User string
	Pass string
	// ServerId is written into both the register-slave and the binlog-dump
	// commands as a 32-bit value.
	ServerId int

	// LogFile is the binlog file name sent in the binlog-dump command.
	LogFile string
	// Position is the offset inside LogFile sent in the binlog-dump command;
	// it is written as a 32-bit value.
	Position int
}

  1. server模块 整个的核心部分,包括连接,注册,发送命令,获取binlog都是在这里。这里的解析binlog使用了go-mysql
package app

import (
	"bufio"
	"bytes"
	"context"
	"crypto/sha1"
	"encoding/binary"
	"errors"
	"fmt"
	"github.com/siddontang/go-mysql/replication"
	"io"
	"net"
	"os"
	"time"
)

// Protocol constants used when talking to the MySQL server.
const (
	// MinProtocolVersion is the lowest value accepted for the first byte of
	// the server's initial handshake packet.
	MinProtocolVersion byte = 10

	// The *_HEADER values are compared against the first payload byte of a
	// packet received from the server to tell the packet kinds apart.
	OK_HEADER          byte = 0x00
	ERR_HEADER         byte = 0xff
	EOF_HEADER         byte = 0xfe
	LocalInFile_HEADER byte = 0xfb
)

// MaxPayloadLength is the largest payload a single packet can carry (the
// length field is three bytes wide). readPacket and writePacket treat a
// payload of exactly this size as "more packets follow".
const MaxPayloadLength = 1<<24 - 1

// Server is the binlog client: it connects to a MySQL master, registers as a
// slave and prints the binlog events it receives.
type Server struct {
	// Cfg holds the connection settings and binlog coordinates; it must be
	// set before Run is called.
	Cfg          *Config
	// Ctx is not read by any method visible in this file.
	Ctx          context.Context
	// conn is the TCP connection opened by handshake and closed by Quit.
	conn         net.Conn
	// io frames packets on top of conn; it is created by handshake.
	io           *PacketIo
	// registerSucc is set to true once the master answered the
	// register-slave command with an OK packet.
	registerSucc bool
}

// Run executes the whole binlog session through dump. Quit is deferred, so
// the quit command is sent and the connection is closed when dump returns or
// panics.
func (s *Server) Run() {
	defer s.Quit()
	s.dump()
}

// dump runs a complete binlog session: it logs in, disables binlog checksums,
// registers as a slave, sends the binlog-dump command and then prints every
// event it receives to standard output.
//
// It panics if the handshake fails. It returns when the connection can no
// longer be read (for example after Quit closed it) or when the server sends
// an EOF packet; parse errors on a single event are printed and skipped.
func (s *Server) dump() {
	if err := s.handshake(); err != nil {
		panic(err)
	}
	s.invalidChecksum()
	fmt.Println("dump ...")
	s.register()
	s.writeDumpCommand()

	parser := replication.NewBinlogParser()
	for {
		data, err := s.io.readPacket()
		if err != nil {
			// A failed read (EOF, closed or desynchronised connection) will
			// fail again on every retry, so looping would only spin the CPU.
			fmt.Println(err)
			return
		}
		if len(data) == 0 {
			continue
		}

		switch data[0] {
		case OK_HEADER:
			// Every event packet starts with an OK byte; the event itself
			// follows it.
			e, err := parser.Parse(data[1:])
			if err != nil {
				fmt.Println(err)
				continue
			}
			e.Dump(os.Stdout)
		case EOF_HEADER:
			// The server has nothing more to send on this stream.
			fmt.Println("binlog stream finished (EOF packet)")
			return
		default:
			s.io.HandleError(data)
		}
	}
}

// invalidChecksum asks the master to send binlog events without a checksum by
// setting @master_binlog_checksum to 'NONE', then consumes the reply.
//
// Failures are printed rather than returned, matching the other setup steps.
func (s *Server) invalidChecksum() {
	sql := `SET @master_binlog_checksum='NONE'`
	if err := s.query(sql); err != nil {
		fmt.Println(err)
		// Nothing was sent, so no reply will arrive; reading now would block.
		return
	}

	// The reply has to be read off the connection, otherwise it would be
	// mistaken for the answer to the next command.
	res, err := s.io.readPacket()
	if err != nil {
		fmt.Println(err)
		return
	}
	if len(res) > 0 && res[0] == ERR_HEADER {
		s.io.HandleError(res)
	}
}

// handshake opens the TCP connection to the configured MySQL server, parses
// the server's initial handshake packet and answers it with a login packet
// built from Cfg.User and Cfg.Pass.
//
// On success s.conn and s.io are ready for further commands and nil is
// returned. An error is returned when the connection cannot be established,
// the server greeting is malformed or too old, the login packet cannot be
// written, or the server answers with anything other than an OK packet.
func (s *Server) handshake() error {
	conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", s.Cfg.Host, s.Cfg.Port), 10*time.Second)
	if err != nil {
		return err
	}

	tc := conn.(*net.TCPConn)
	// Socket tuning is best effort; the session works without it.
	_ = tc.SetKeepAlive(true)
	_ = tc.SetNoDelay(true)
	s.conn = tc

	s.io = &PacketIo{}
	s.io.r = bufio.NewReaderSize(s.conn, 16*1024)
	s.io.w = tc

	data, err := s.io.readPacket()
	if err != nil {
		return err
	}
	if len(data) == 0 {
		return errors.New("empty handshake packet")
	}

	if data[0] == ERR_HEADER {
		return errors.New("error packet")
	}

	if data[0] < MinProtocolVersion {
		return fmt.Errorf("version is too lower, current:%d", data[0])
	}

	// Skip the protocol version byte and the NUL-terminated server version.
	versionEnd := bytes.IndexByte(data[1:], 0x00)
	if versionEnd == -1 {
		return errors.New("malformed handshake packet")
	}
	pos := 1 + versionEnd + 1

	// Fixed part: connection id (4), first salt part (8), filler (1),
	// lower capability flags (2).
	if len(data) < pos+4+8+1+2 {
		return errors.New("malformed handshake packet")
	}
	connId := binary.LittleEndian.Uint32(data[pos : pos+4])
	pos += 4
	// Copy the salt out of the packet: appending to a sub-slice of data would
	// write into the packet's own backing array.
	salt := append([]byte(nil), data[pos:pos+8]...)

	pos += 8 + 1
	capability := uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))

	pos += 2

	var status uint16
	var pluginName string
	if len(data) > pos {
		// Extended part: charset (1), status (2), upper capability flags (2),
		// auth data length (1), reserved (10), second salt part (12).
		if len(data) < pos+1+2+2+1+10+12 {
			return errors.New("malformed handshake packet")
		}
		//skip charset
		pos++
		status = binary.LittleEndian.Uint16(data[pos : pos+2])
		pos += 2
		capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | capability
		pos += 2

		pos += 10 + 1
		salt = append(salt, data[pos:pos+12]...)
		// 12 salt bytes plus their terminating NUL.
		pos += 13

		if pos < len(data) {
			if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
				pluginName = string(data[pos : pos+end])
			} else {
				pluginName = string(data[pos:])
			}
		}
	}

	fmt.Printf("conn_id:%v, status:%d, plugin:%v\n", connId, status, pluginName)

	//write
	// The server's flags are replaced by a fixed client capability set
	// (0x7A285); it includes the 4.1 protocol (0x200) and secure-connection
	// (0x8000) bits that the packet layout below relies on.
	capability = 500357
	// header (4) + capability (4) + max packet size (4) + charset (1) +
	// reserved (23) + user + NUL + auth length (1) + auth.
	length := 4 + 4 + 1 + 23
	length += len(s.Cfg.User) + 1

	auth := calPassword(salt, []byte(s.Cfg.Pass))
	length += 1 + len(auth)
	data = make([]byte, length+4)

	binary.LittleEndian.PutUint32(data[4:], capability)
	// data[8:12], the max packet size, is left at zero.

	//utf8
	data[12] = byte(33)
	pos = 13 + 23
	pos += copy(data[pos:], s.Cfg.User)

	// Step over the NUL that terminates the user name.
	pos++
	data[pos] = byte(len(auth))
	copy(data[pos+1:], auth)

	if err = s.io.writePacket(data); err != nil {
		return fmt.Errorf("write auth packet error: %w", err)
	}

	pk, err := s.io.readPacket()
	if err != nil {
		return err
	}
	if len(pk) == 0 {
		return errors.New("empty auth response")
	}

	switch pk[0] {
	case OK_HEADER:
		fmt.Println("handshake ok ")
		return nil
	case ERR_HEADER:
		s.io.HandleError(pk)
		return errors.New("handshake error ")
	default:
		// Anything else (for example an auth-switch request) means the login
		// has not completed, so it must not be reported as success.
		return fmt.Errorf("unexpected auth response header 0x%02x", pk[0])
	}
}

// writeDumpCommand sends the binlog-dump command (command byte 18) carrying
// Cfg.Position, a zero flags field, Cfg.ServerId and Cfg.LogFile, then reads
// the first packet the server sends back.
//
// Failures are printed rather than returned.
func (s *Server) writeDumpCommand() {
	s.io.seq = 0
	// header (4) + command (1) + position (4) + flags (2) + server id (4) +
	// file name.
	data := make([]byte, 4+1+4+2+4+len(s.Cfg.LogFile))
	pos := 4
	data[pos] = 18 //dump binlog
	pos++
	binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.Position))
	pos += 4

	//dump command flag
	binary.LittleEndian.PutUint16(data[pos:], 0)
	pos += 2

	binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId))
	pos += 4

	copy(data[pos:], s.Cfg.LogFile)

	if err := s.io.writePacket(data); err != nil {
		fmt.Println(err)
		return
	}

	// NOTE(review): the server presumably answers this command with the event
	// stream itself, so the packet read here may already be the first event
	// rather than a plain OK - confirm whether it should be parsed instead.
	res, err := s.io.readPacket()
	if err != nil {
		fmt.Println(err)
		return
	}
	if len(res) == 0 {
		return
	}
	if res[0] == OK_HEADER {
		fmt.Println("send dump command return ok.")
	} else {
		s.io.HandleError(res)
	}
}

// register sends the register-slave command (command byte 21) so the master
// lists this client as a replica, and records the outcome in s.registerSucc.
//
// The packet carries Cfg.ServerId, the local host name, Cfg.User, Cfg.Pass and
// Cfg.Port, followed by two zeroed 32-bit fields. Failures are printed rather
// than returned.
func (s *Server) register() {
	s.io.seq = 0
	// An unknown host name is not fatal; an empty name is sent instead.
	hostname, _ := os.Hostname()
	data := make([]byte, 4+1+4+1+len(hostname)+1+len(s.Cfg.User)+1+len(s.Cfg.Pass)+2+4+4)
	pos := 4
	data[pos] = 21 //register slave  command
	pos++
	binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId))
	pos += 4

	// Each string is written as a one-byte length followed by its bytes.
	data[pos] = uint8(len(hostname))
	pos++
	pos += copy(data[pos:], hostname)

	data[pos] = uint8(len(s.Cfg.User))
	pos++
	pos += copy(data[pos:], s.Cfg.User)

	data[pos] = uint8(len(s.Cfg.Pass))
	pos++
	pos += copy(data[pos:], s.Cfg.Pass)

	binary.LittleEndian.PutUint16(data[pos:], uint16(s.Cfg.Port))
	pos += 2

	// Four bytes left at zero (presumably the replication rank field).
	binary.LittleEndian.PutUint32(data[pos:], 0)
	pos += 4

	//master id = 0
	binary.LittleEndian.PutUint32(data[pos:], 0)

	if err := s.io.writePacket(data); err != nil {
		fmt.Println(err)
		return
	}

	//ok
	res, err := s.io.readPacket()
	if err != nil {
		fmt.Println(err)
		return
	}
	if len(res) == 0 {
		fmt.Println("empty response to register command")
		return
	}
	if res[0] == OK_HEADER {
		fmt.Println("register success.")
		s.registerSucc = true
		return
	}
	// Decode the server's reply, not the request that was just sent.
	s.io.HandleError(res)
}

// writeCommand starts a new command sequence and sends a packet whose payload
// is the single byte command. A write failure is deliberately ignored.
func (s *Server) writeCommand(command byte) {
	s.io.seq = 0

	// Four header bytes followed by the one-byte payload; writePacket fills
	// in the real header before sending.
	packet := make([]byte, 5)
	packet[0] = 0x01 //1 byte long
	packet[4] = command

	_ = s.io.writePacket(packet)
}

// query starts a new command sequence and sends q as a query command
// (command byte 3). It returns the error reported by writePacket and does not
// read the server's reply.
func (s *Server) query(q string) error {
	s.io.seq = 0

	// Reserve the four header bytes, then append the command byte and the
	// statement text.
	packet := make([]byte, 4, 4+1+len(q))
	packet = append(packet, 3)
	packet = append(packet, q...)

	return s.io.writePacket(packet)
}

// Quit sends the quit command (command byte 1) and closes the connection.
//
// It is safe to call when the connection was never established, for example
// when the dial in handshake failed or Quit runs before the handshake; in
// that case it does nothing. A close error is printed, not returned.
func (s *Server) Quit() {
	if s.conn == nil {
		return
	}

	//quit
	if s.io != nil {
		s.writeCommand(byte(1))
	}
	//maybe only close
	if err := s.conn.Close(); err != nil {
		fmt.Printf("error in close :%v\n", err)
	}
}


// PacketIo reads and writes length-prefixed packets: a three-byte
// little-endian payload length, a one-byte sequence number, then the payload.
type PacketIo struct {
	// r is the buffered source packets are read from.
	r   *bufio.Reader
	// w is the destination packets are written to.
	w   io.Writer
	// seq is the sequence number expected on the next packet read and stamped
	// on the next packet written; commands reset it to zero before sending.
	seq uint8
}

// readPacket reads one logical packet from p.r and returns its payload.
//
// A payload of exactly MaxPayloadLength bytes means the data continues in the
// next packet, so the following packets are read and concatenated until a
// shorter one arrives. Each packet's sequence number must equal p.seq, which
// is incremented per packet. A packet with an empty payload is returned as an
// empty, non-nil slice.
//
// It returns the underlying read error unchanged, or an "invalid seq" error
// when a packet arrives out of sequence.
func (p *PacketIo) readPacket() ([]byte, error) {
	var payload []byte
	for {
		//to read header
		var header [4]byte
		if _, err := io.ReadFull(p.r, header[:]); err != nil {
			return nil, err
		}

		length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
		if length == 0 {
			// An empty packet either stands alone or terminates a payload
			// that was an exact multiple of MaxPayloadLength.
			p.seq++
			if payload == nil {
				return []byte{}, nil
			}
			return payload, nil
		}

		if seq := header[3]; seq != p.seq {
			return nil, fmt.Errorf("invalid seq %d", seq)
		}
		p.seq++

		// Any non-zero length is valid, including a single byte; the payload
		// must always be consumed so the stream stays aligned.
		chunk := make([]byte, length)
		if _, err := io.ReadFull(p.r, chunk); err != nil {
			return nil, err
		}
		if payload == nil {
			payload = chunk
		} else {
			payload = append(payload, chunk...)
		}

		if length < MaxPayloadLength {
			return payload, nil
		}
	}
}

// writePacket sends data as one or more packets on p.w.
//
// data must start with four spare bytes for the packet header, followed by
// the payload; the header bytes are overwritten here. A payload of
// MaxPayloadLength bytes or more is split into full-size packets followed by
// a final shorter (possibly empty) one. When splitting, the last four bytes
// of the chunk already sent are reused as the next header, so data is
// modified in place. p.seq is stamped on and incremented for every packet.
//
// It returns an error when data is shorter than a header, when the writer
// fails, or when the writer accepts fewer bytes than requested.
func (p *PacketIo) writePacket(data []byte) error {
	if len(data) < 4 {
		return errors.New("packet buffer shorter than header")
	}

	length := len(data) - 4
	// Loop, not a single check: a payload of two or more full chunks needs
	// one full-size packet per chunk.
	for length >= MaxPayloadLength {
		data[0] = 0xff
		data[1] = 0xff
		data[2] = 0xff
		data[3] = p.seq

		n, err := p.w.Write(data[:4+MaxPayloadLength])
		if err != nil {
			return fmt.Errorf("write find error: %w", err)
		}
		if n != 4+MaxPayloadLength {
			return errors.New("not equal max pay load length")
		}

		p.seq++
		length -= MaxPayloadLength
		data = data[MaxPayloadLength:]
	}

	data[0] = byte(length)
	data[1] = byte(length >> 8)
	data[2] = byte(length >> 16)
	data[3] = p.seq

	n, err := p.w.Write(data)
	if err != nil {
		return fmt.Errorf("write find error: %w", err)
	}
	if n != len(data) {
		return errors.New("not equal length")
	}

	p.seq++
	return nil
}

// calPassword computes the scrambled password sent in the login packet:
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// scramble is the salt received from the server. An empty password yields a
// nil result, so the login packet carries a zero-length auth response, which
// is what an account without a password requires. Neither argument is
// modified.
func calPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	stage1 := sha1.Sum(password)
	stage2 := sha1.Sum(stage1[:])

	h := sha1.New()
	h.Write(scramble)
	h.Write(stage2[:])
	token := h.Sum(nil)

	for i := range token {
		token[i] ^= stage1[i]
	}

	return token
}

// HandleError decodes an error packet and prints its code, SQL state and
// message to standard output.
//
// data is the packet payload including the leading header byte. The expected
// layout is: header (1), error code (2, little-endian), '#', SQL state (5),
// message. Packets that are shorter or lack the '#' marker are printed with
// whatever fields are present instead of panicking.
func (p *PacketIo) HandleError(data []byte) {
	var (
		code  uint16
		state string
		rest  = data
	)

	if len(data) >= 3 {
		code = binary.LittleEndian.Uint16(data[1:3])
		rest = data[3:]
	}
	// The SQL state is only present when announced by the '#' marker.
	if len(rest) >= 6 && rest[0] == '#' {
		state = string(rest[1:6])
		rest = rest[6:]
	}

	fmt.Printf("code:%d, state:%s, msg:%s\n", code, state, rest)
}

  1. main
package main

import (
	"flag"
	"fmt"
	"github.com/igoso/gbinlog/app"
	"os"
	"os/signal"
	"runtime"
	"syscall"
)

// Command-line flags describing the MySQL master and where to start reading.
var myHost = flag.String("host", "127.0.0.1", "MySQL replication host")
var myPort = flag.Int("port", 3306, "MySQL replication port")
var myUser = flag.String("user", "root", "MySQL replication user")
var myPass = flag.String("pass", "****", "MySQL replication pass")
var serverId = flag.Int("server_id", 1111, "MySQL replication server id")

// The binlog coordinates default to the values that used to be hard-coded.
var logFile = flag.String("file", "mysql-bin.000032", "MySQL binlog file to start from")
var logPos = flag.Int("pos", 3070, "MySQL binlog position to start from")

// main starts the binlog client in a goroutine and blocks until a
// termination signal arrives, then closes the connection through Quit.
func main() {
	sc := make(chan os.Signal, 1)
	// SIGKILL cannot be caught, so it is not requested; os.Interrupt already
	// covers SIGINT.
	signal.Notify(sc,
		os.Interrupt,
		syscall.SIGHUP,
		syscall.SIGQUIT,
		syscall.SIGTERM,
	)

	runtime.GOMAXPROCS(runtime.NumCPU()/4 + 1)
	flag.Parse()
	// Keyed fields keep this literal correct if app.Config gains or reorders
	// fields.
	cfg := &app.Config{
		Host:     *myHost,
		Port:     *myPort,
		User:     *myUser,
		Pass:     *myPass,
		ServerId: *serverId,
		LogFile:  *logFile,
		Position: *logPos,
	}
	srv := &app.Server{Cfg: cfg}
	go srv.Run()

	n := <-sc
	srv.Quit()
	fmt.Printf("receive signal %v, closing", n)
}

  1. go.mod 只有一个依赖
module github.com/igoso/gbinlog

go 1.15

require (
	github.com/siddontang/go-mysql v1.1.0
)

其他

注意:如果binlog dump连接执行了quit命令,在MySQL端查看时,该连接不会立刻消失,而是处于CLOSE_WAIT状态;等下次有新的连接进来后,它才会消失并建立新的连接。这期间可能出现1236错误(已存在相同server_id的从库),但不影响使用。

本来在尝试自己解析binlog,但实际做的话工作量还是很大的,因为有很多种类的binlog event需要处理。后来在siddontang的go-mysql包中发现已经实现了一个很好用的binlogSyncer,其中就有完善的解析方法。他实现的binlogSyncer用起来也非常方便,感兴趣的可以参考如下代码。

package main

import (
	"context"
	"flag"
	"fmt"
	"os"

	"github.com/pingcap/errors"
	"github.com/siddontang/go-mysql/mysql"
	"github.com/siddontang/go-mysql/replication"
)

// Command-line flags for the BinlogSyncer example.
var (
	// Connection settings for the MySQL server to replicate from.
	host     = flag.String("host", "127.0.0.1", "MySQL host")
	port     = flag.Int("port", 3306, "MySQL port")
	user     = flag.String("user", "root", "MySQL user, must have replication privilege")
	password = flag.String("password", "****", "MySQL password")

	// flavor is passed to the syncer configuration unchanged.
	flavor = flag.String("flavor", "mysql", "Flavor: mysql or mariadb")

	// Binlog coordinates at which reading starts.
	file = flag.String("file", "mysql-bin.000032", "Binlog filename")
	pos  = flag.Int("pos", 3070, "Binlog position")

	// Optional behaviour switches.
	semiSync   = flag.Bool("semisync", false, "Support semi sync")
	backupPath = flag.String("backup_path", "", "backup path to store binlog files")

	rawMode = flag.Bool("raw", false, "Use raw mode")
)

// main connects a BinlogSyncer to the configured server and either backs the
// binlog files up to -backup_path or prints every received event to standard
// output until an error occurs.
func main() {
	flag.Parse()

	cfg := replication.BinlogSyncerConfig{
		ServerID: 101,
		Flavor:   *flavor,

		Host:            *host,
		Port:            uint16(*port),
		User:            *user,
		Password:        *password,
		RawModeEnabled:  *rawMode,
		SemiSyncEnabled: *semiSync,
		UseDecimal:      true,
	}

	b := replication.NewBinlogSyncer(cfg)
	// Close the syncer on every exit path so its connection to the server is
	// released instead of being left behind.
	defer b.Close()

	// Named start rather than pos so the -pos flag variable is not shadowed.
	start := mysql.Position{Name: *file, Pos: uint32(*pos)}

	if len(*backupPath) > 0 {
		// Backup will always use RawMode.
		if err := b.StartBackup(*backupPath, start, 0); err != nil {
			fmt.Printf("Start backup error: %v\n", errors.ErrorStack(err))
		}
		return
	}

	s, err := b.StartSync(start)
	if err != nil {
		fmt.Printf("Start sync error: %v\n", errors.ErrorStack(err))
		return
	}

	for {
		e, err := s.GetEvent(context.Background())
		if err != nil {
			// Try to output all left events
			for _, left := range s.DumpEvents() {
				left.Dump(os.Stdout)
			}
			fmt.Printf("Get event error: %v\n", errors.ErrorStack(err))
			return
		}

		e.Dump(os.Stdout)
	}
}

以上就是本期的全部内容。