无奈!我用go写了个MySQL服务

70 阅读18分钟

一个程序员的“被迫营业”故事

序:同事的困境

嘿,各位程序员大佬们!今天我要给你们讲一个既荒诞又真实的故事——关于我如何被同事的易语言代码“逼”成了“MySQL协议专家”,注意引号,不是真专家。

事情是这样的,我有个同事,有一个程序是用易语言写的,古董级的。最近要做升级,要缓存一个4G左右的数据,可是易语言是32位程序,做不到,于是我就用go开发了一个HTTP接口,数据我缓存,他用易语言调用这个接口,结果遇到了一个世纪难题:易语言使用http读文件不是线程不安全!是的,你没听错,在2025年,还有人在为易语言的线程安全问题挠头。

灵感乍现:“既然易语言搞不定HTTP,那我们为什么不用MySQL协议呢?”

那天,他调试来调试去,线程安全问题就是解决不了!最后,我盯着他足足看了三十秒钟,突然脑子里闪过一道光——对啊!MySQL协议是成熟稳定的数据库协议,几乎所有语言都有完善的驱动支持,包括易语言!

于是,一个大胆的想法诞生了:写一个简单的MySQL协议服务端,将HTTP接口伪装成MySQL查询。这样,同事就可以用易语言通过MySQL驱动轻松调用,完美避开线程安全的坑!

开发历程:从0到1的MySQL协议实现

说干就干!我用Go语言(感谢Go的并发模型!)开始了这个"伪MySQL服务器"的开发。过程中遇到了不少挑战:

  1. 协议解析:MySQL协议虽然公开,但细节繁多,特别是握手认证部分
  2. 连接处理:要支持多客户端同时连接,Go的goroutine正好派上用场
  3. 兼容性处理:不同语言的MySQL驱动初始化时会发送各种查询,必须一一处理

功能特性:麻雀虽小,五脏俱全

经过一番努力,这个"不正经"的MySQL服务器已经具备了以下功能:

  • ✅ MySQL协议兼容,可以被各种语言的MySQL驱动正常连接
  • ✅ 支持基本的数据库操作命令
  • ✅ 处理客户端初始化查询(比如字符集、校对规则等)
  • ✅ 自定义SQL命令支持:这里用了两个示例接口,没有真实的业务代码
  • ✅ 数据库和表的模拟(通过文件系统目录结构)

自定义命令使用示例

GET MB 命令 - 字符串回显功能

mysql> get mb hello world;
+--------+-------------+
| key    | value       |
+--------+-------------+
| result | hello world |
+--------+-------------+
1 row in set

2. GET MD5 命令 - MD5计算功能

mysql> get md5 test123;
+--------+----------------------------------+
| key    | value                            |
+--------+----------------------------------+
| result | 22b75d6007e06f4a959d1b1d69b4c4bd |
+--------+----------------------------------+
1 row in set

核心代码解析

1. 自定义命令处理机制

自定义命令是这个服务的核心功能之一,下面我们来看看它是如何实现的:

// 处理自定义SQL命令
if strings.HasPrefix(strings.ToUpper(sql), "GET MB") {
    parts := strings.Fields(strings.ToUpper(sql))
    var input string
    if len(parts) > 2 {
        input = strings.Join(parts[2:], " ") // 提取命令参数
    } else {
        input = "default" // 默认值处理
    }
    
    // 返回输入的字符串
    mbResult := input
    
    // 构建MySQL响应结构
    rows := [][]string{
        {"result", mbResult},
    }
    
    mr := MysqlResponse{
        // 定义字段结构
        Fs: []FieldProtocol{...},
        Rows: rows,
    }
    c.Conn.Write(mr.GetBytes()) // 发送响应
} else if strings.HasPrefix(strings.ToUpper(sql), "GET MD5") {
    // MD5命令处理逻辑...
} else {
    // 处理标准SQL命令...
}

这段代码展示了自定义命令的处理流程:

  1. 命令识别:通过字符串前缀判断命令类型
  2. 参数解析:提取命令后面的参数部分
  3. 业务处理:执行相应的业务逻辑(字符串回显或MD5计算)
  4. 响应构建:创建标准的MySQL协议响应结构
  5. 数据发送:将响应序列化为二进制数据并发送

2. MySQL协议响应结构

服务通过MysqlResponse结构体来构建符合MySQL协议的响应数据:

// MysqlResponse MySQL响应结构
type MysqlResponse struct {
    Fs   []FieldProtocol // 字段定义
    Rows [][]string      // 数据行
}

// GetBytes 将响应转换为二进制数据
func (mr *MysqlResponse) GetBytes() []byte {
    // 构建响应头部
    var buf bytes.Buffer
    buf.WriteByte(0x01) // 包头
    
    // 写入字段定义
    for _, field := range mr.Fs {
        buf.Write(field.GetBytes())
    }
    
    // 写入EOF包
    buf.Write(getEOFPacket())
    
    // 写入数据行
    for _, row := range mr.Rows {
        buf.Write(getRowPacket(row))
    }
    
    // 写入最终的EOF包
    buf.Write(getEOFPacket())
    
    return buf.Bytes()
}

这个结构严格遵循MySQL的文本协议格式,包含:

  • 字段定义:描述返回数据的结构
  • 数据行:实际的查询结果
  • 特殊标记:如EOF包,用于分隔不同的协议阶段

3. 连接处理与认证流程

服务使用Go的goroutine为每个客户端连接创建独立的处理线程:

// Start 启动连接处理
func (c *ConnHandle) Start() {
    defer func() {
        c.Conn.Close()
        wg.Done()
    }()
    
    // 发送握手包
    c.WriteOnePack()
    
    // 读取并解析客户端响应
    data, err := c.ReadOnePack()
    if err != nil {
        return
    }
    
    // 处理认证
    username, password, dbname, err := parseClientHandshakePacket(data)
    if err != nil || !isPassScrambleMysqlNativePassword(password, c.Salt) {
        c.writeErrorPacket(1045, "28000", "用户名或密码错误")
        return
    }
    
    // 发送认证成功响应
    c.writeOKPacket()
    
    // 主命令处理循环
    for !exit {
        err = c.handleNextCommand()
        if err != nil {
            break
        }
    }
}

这个连接处理流程包括:

  1. 握手初始化:服务器发送握手包给客户端
  2. 身份验证:验证用户名密码(本实现中简化了验证逻辑)
  3. 命令循环:持续接收并处理客户端命令
  4. 资源清理:连接结束时释放资源

4. 配置文件读取

服务通过readConfig函数读取配置文件:

// readConfig 读取配置文件
func readConfig(configPath string) (string, string, error) {
  pz, err := config.ReadDefault(configPath)
	if err != nil {
		Loger.Error(err)
		return err
	}
	if dbRoot, err = pz.String("mysqld", "datadir"); err != nil {
		Loger.Error(err)
		return err
	}
	return nil
}

这个函数负责从配置文件中读取服务配置,目前只有数据目录路径

配置文件如下: my.ini

[mysqld]
datadir=./data

程序源码如下:

package main

import (
	"bytes"
	"crypto/md5"
	"crypto/rand"
	"crypto/sha1"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"gitcode.com/jjgtmgx/mgxlog"
	"github.com/larspensjo/config"
)

// 最大数据包大小
const (
	MaxPacketSize   = (1 << 24) - 1 // 服务器支持的最大数据包大小
	ProtocolVersion = 10            // MySQL协议版本,固定为10
)

// 认证方式常量
const (
	// MysqlNativePassword 身份验证方式
	MysqlNativePassword = "mysql_native_password"
)

// 能力标志常量
const (
	// CapabilityClientFoundRows 返回找到的行数而不是受影响的行数
	CapabilityClientFoundRows = 1 << 1

	// CapabilityClientConnectWithDB 可以在连接时指定数据库
	CapabilityClientConnectWithDB = 1 << 3

	// CapabilityClientProtocol41 新的4.1协议,必须支持
	CapabilityClientProtocol41 = 1 << 9

	// CapabilityClientSecureConnection 新的4.1身份验证方式
	CapabilityClientSecureConnection = 1 << 15

	// CapabilityClientMultiStatements 支持在COM_QUERY和COM_STMT_PREPARE中处理多个语句
	CapabilityClientMultiStatements = 1 << 16

	// CapabilityClientPluginAuth 客户端支持插件身份验证
	CapabilityClientPluginAuth = 1 << 19

	// CapabilityClientConnAttr 允许在Protocol::HandshakeResponse41中使用连接属性
	CapabilityClientConnAttr = 1 << 20

	// CapabilityClientDeprecateEOF 期望在文本结果集的行之后使用OK(而不是EOF)
	CapabilityClientDeprecateEOF = 1 << 24
)

// 数据包类型常量
const (
	// ComQuit 客户端请求关闭连接
	ComQuit = 0x01

	// ComInitDB 客户端请求切换数据库
	ComInitDB = 0x02

	// ComQuery 客户端发送SQL查询
	ComQuery = 0x03

	// ComPing 客户端发送ping请求
	ComPing = 0x0e

	// ComSetOption 客户端设置选项
	ComSetOption = 0x1b

	// OKPacket OK数据包的头部标识
	OKPacket = 0x00

	// EOFPacket EOF数据包的头部标识
	EOFPacket = 0xfe

	// ErrPacket 错误数据包的头部标识
	ErrPacket = 0xff
)

// 日志记录
var Loger, _ = mgxlog.NewMgxLog("runlog/", 10*1024*1024, 100, 3, 1000)


var exit bool = false
var wg sync.WaitGroup
var tidchan = make(chan uint32)

var dbRoot string

func main() {
	defer Loger.Flush()
	addr := flag.String("addr", ":3307", "http service address")
	config := flag.String("config", "./my.ini", "configuration file path")
	flag.Parse()

	// 读取 my.ini 获取 datadir
	if err := readConfig(*config); err != nil {
		return
	}

	go CreateTid()
	wg.Add(1)
	go StartServer(*addr)
	for {
		var cmd string
		fmt.Scanf("%s", &cmd)
		if cmd == "exit" {
			exit = true
			break
		}
		fmt.Println("未知命令")
		fmt.Println("exit 退出程序")
	}
	wg.Wait()
}

// readConfig 读取配置文件
func readConfig(configPath string) error {
	pz, err := config.ReadDefault(configPath)
	if err != nil {
		Loger.Error(err)
		return err
	}
	if dbRoot, err = pz.String("mysqld", "datadir"); err != nil {
		Loger.Error(err)
		return err
	}
	return nil
}

// StartServer 启动服务器
func StartServer(addr string) {
	defer wg.Done()
	var netListen net.Listener
	for !exit {
		if netListen == nil {
			var err error
			if netListen, err = net.Listen("tcp", addr); err != nil {
				Loger.Error(err)
			} else {
				go func() {
					for !exit {
						conn, err := netListen.Accept()
						if err != nil {
							continue
						}
						ch := ConnHandle{Conn: conn}
						go ch.Start()
					}
				}()
			}
		} else {
			time.Sleep(2 * time.Second)
		}
	}
	if netListen != nil {
		netListen.Close()
	}
}

// CreateTid 生成事务ID
func CreateTid() {
	tid := uint32(1)
	for {
		tidchan <- tid
		tid++
		if tid > 999999999 {
			tid = 1
		}
	}
}

var ServerVersion = "5.5.15"

type ConnHandle struct {
	Conn         net.Conn
	sequence     uint8
	Capabilities uint32
	SchemaName   string
	CharacterSet uint8
	User         string
}

// Start 启动连接处理
func (ch *ConnHandle) Start() {
	defer ch.Conn.Close()
	salt, err := ch.WriteOnePack()
	if err != nil {
		Loger.Error(ch.Conn.RemoteAddr().String(), err)
		return
	}
	b, err := ch.ReadOnePack()
	if err != nil {
		Loger.Error(ch.Conn.RemoteAddr().String(), err)
		return
	}
	user, _, authResponse, err := ch.parseClientHandshakePacket(true, b)
	if err != nil {
		Loger.Errorf("无法解析来自 %s 的客户端握手响应: %v", ch.Conn, err)
		return
	}
	if !isPassScrambleMysqlNativePassword(authResponse, salt, "root") {
		ch.writeErrorPacket(1045, "00001", "用户名或密码错误")
		return
	}
	ch.User = user
	if err := ch.writeOKPacket(0, 0, 0, 0); err != nil {
		Loger.Errorf("无法向 %s 写入OK包: %v", ch.Conn, err)
		return
	}
	for !exit {
		err := ch.handleNextCommand()
		if err != nil {
			Loger.Error(err)
			return
		}
	}
}

// handleNextCommand 处理下一个命令
func (c *ConnHandle) handleNextCommand() error {
	c.sequence = 0
	data, err := c.readEphemeralPacket()
	if err != nil {
		return err
	}
	switch data[0] {
	case ComQuit:
		return errors.New("ComQuit")
	case ComInitDB:
		dbName := string(data[1:])
		if _, err := os.Stat(filepath.Join(dbRoot, dbName)); os.IsNotExist(err) {
			c.writeErrorPacket(1049, "42000", fmt.Sprintf("未知数据库 '%s'", dbName))
			return nil
		}
		c.SchemaName = dbName
		c.writeOKPacket(0, 0, 0, 0)
	case ComQuery:
		sql := string(data[1:])
		sql = collapseSpaces(sql)
		sql = strings.TrimRight(sql, ";")
		sql = strings.TrimSpace(sql)
		sql = strings.Join(strings.Fields(sql), " ")
		// 处理自定义SQL命令: get mb xxx
		if strings.HasPrefix(strings.ToUpper(sql), "GET MB") {
			parts := strings.Fields(strings.ToUpper(sql))
			var input string
			if len(parts) > 2 {
				input = strings.Join(parts[2:], " ")
			} else {
				input = "default"
			}
			
			// 返回输入的字符串
			mbResult := input
			
			rows := [][]string{
				{"result", mbResult},
			}
			
			mr := MysqlResponse{
				Fs: []FieldProtocol{
					{
						Catalog:       "def",
						Database:      "",
						Table:         "",
						OriginalTable: "",
						Name:          "key",
						OriginalName:  "key",
						Charset:       33,
						Length:        50,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
					{
						Catalog:       "def",
						Database:      "",
						Table:         "",
						OriginalTable: "",
						Name:          "value",
						OriginalName:  "value",
						Charset:       33,
						Length:        50,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
				},
				Rows: rows,
			}
			c.Conn.Write(mr.GetBytes())
		} else if strings.HasPrefix(strings.ToUpper(sql), "GET MD5") {
			parts := strings.Fields(strings.ToUpper(sql))
			var input string
			if len(parts) > 2 {
				input = strings.Join(parts[2:], " ")
			} else {
				input = "default"
			}
			
			// 计算MD5值
			hash := md5.Sum([]byte(input))
			md5Result := hex.EncodeToString(hash[:])
			
			rows := [][]string{
				{"result", md5Result},
			}
			
			mr := MysqlResponse{
				Fs: []FieldProtocol{
					{
						Catalog:       "def",
						Database:      "",
						Table:         "",
						OriginalTable: "",
						Name:          "key",
						OriginalName:  "key",
						Charset:       33,
						Length:        50,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
					{
						Catalog:       "def",
						Database:      "",
						Table:         "",
						OriginalTable: "",
						Name:          "value",
						OriginalName:  "value",
						Charset:       33,
						Length:        50,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
				},
				Rows: rows,
			}
			c.Conn.Write(mr.GetBytes())
		} else if strings.HasPrefix(strings.ToUpper(sql), "SELECT COLLATIONS") {
			rows := [][]string{}
			for i := 0; i < 1000; i++ {
				rows = append(rows, []string{strconv.Itoa(i)})
			}
			mr := MysqlResponse{
				Fs: []FieldProtocol{
					{
						Catalog:       "def",
						Database:      "information_schema",
						Table:         "COLLATIONS",
						OriginalTable: "COLLATIONS",
						Name:          "result",
						OriginalName:  "result",
						Charset:       33,
						Length:        96,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
				},
				Rows: rows,
			}
			c.Conn.Write(mr.GetBytes())
		} else if strings.HasPrefix(strings.ToUpper(sql), "SHOW DATABASES") {
			dbs, err := os.ReadDir(dbRoot)
			if err != nil {
				c.writeErrorPacket(1049, "HY000", "无法读取数据库目录")
				return nil
			}
			var rows [][]string
			for _, entry := range dbs {
				if entry.IsDir() {
					rows = append(rows, []string{entry.Name()})
				}
			}
			mr := MysqlResponse{
				Fs: []FieldProtocol{
					{
						Catalog:       "def",
						Database:      "information_schema",
						Table:         "SCHEMATA",
						OriginalTable: "SCHEMATA",
						Name:          "Database",
						OriginalName:  "SCHEMA_NAME",
						Charset:       33,
						Length:        192,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
				},
				Rows: rows,
			}
			c.Conn.Write(mr.GetBytes())
		} else if strings.HasPrefix(sql, "SELECT @@character_set_database, @@collation_database") {
			rows := [][]string{
				{"utf8mb4", "utf8mb4_general_ci"},
			}
			mr := MysqlResponse{
				Fs: []FieldProtocol{
					{
						Catalog:       "def",
						Database:      "information_schema",
						Table:         "SCHEMATA",
						OriginalTable: "SCHEMATA",
						Name:          "@@character_set_database",
						OriginalName:  "@@character_set_database",
						Charset:       33,
						Length:        192,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
					{
						Catalog:       "def",
						Database:      "information_schema",
						Table:         "SCHEMATA",
						OriginalTable: "SCHEMATA",
						Name:          "@@collation_database",
						OriginalName:  "@@collation_database",
						Charset:       33,
						Length:        192,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
				},
				Rows: rows,
			}
			c.Conn.Write(mr.GetBytes())
		} else if strings.HasPrefix(sql, "SHOW FULL TABLES") {
			if c.SchemaName == "" {
				c.writeErrorPacket(1046, "3D000", "未选择数据库")
				return nil
			}
			dbPath := filepath.Join(dbRoot, c.SchemaName)
			entries, err := os.ReadDir(dbPath)
			if err != nil {
				c.writeErrorPacket(1049, "42000", fmt.Sprintf("Unknown database '%s'", c.SchemaName))
				return nil
			}
			var rows [][]string
			for _, entry := range entries {
				if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".mgx") {
					tableName := strings.TrimSuffix(entry.Name(), ".mgx")
					rows = append(rows, []string{tableName, "BASE TABLE"})
				}
			}
			mr := MysqlResponse{
				Fs: []FieldProtocol{
					{
						Catalog:       "def",
						Database:      "information_schema",
						Table:         "SCHEMATA",
						OriginalTable: "SCHEMATA",
						Name:          "Tables_in_" + c.SchemaName,
						OriginalName:  "Tables_in_" + c.SchemaName,
						Charset:       33,
						Length:        192,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
					{
						Catalog:       "def",
						Database:      "information_schema",
						Table:         "SCHEMATA",
						OriginalTable: "SCHEMATA",
						Name:          "Table_type",
						OriginalName:  "Table_type",
						Charset:       33,
						Length:        192,
						Type:          253,
						Flags:         1,
						Decimals:      0,
					},
				},
				Rows: rows,
			}
			c.Conn.Write(mr.GetBytes())
		} else {
			c.writeOKPacket(0, 0, 0, 0)
		}
	case ComPing:
		c.writeOKPacket(0, 0, 0, 0)
	case ComSetOption:
	default:
	}
	return nil
}

// collapseSpaces 压缩字符串中的空格
func collapseSpaces(input string) string {
	var buf bytes.Buffer
	reader := strings.NewReader(input)
	prevWasSpace := false

	for {
		r, _, err := reader.ReadRune()
		if err != nil {
			break
		}
		if r == ' ' {
			if !prevWasSpace {
				buf.WriteRune(r)
			}
			prevWasSpace = true
		} else {
			buf.WriteRune(r)
			prevWasSpace = false
		}
	}

	return buf.String()
}

// writeOKPacket 写入OK数据包
func (ch *ConnHandle) writeOKPacket(affectedRows, lastInsertID uint64, flags uint16, warnings uint16) error {
	buf := bytes.NewBuffer([]byte{})
	buf.WriteByte(OKPacket)
	binary.Write(buf, binary.LittleEndian, affectedRows)
	binary.Write(buf, binary.LittleEndian, lastInsertID)
	binary.Write(buf, binary.LittleEndian, flags)
	binary.Write(buf, binary.LittleEndian, warnings)
	return ch.writePacket(buf.Bytes())
}

// writeErrorPacket 写入错误数据包
func (ch *ConnHandle) writeErrorPacket(errorCode uint16, sqlState, errMessage string) error {
	buf := bytes.NewBuffer([]byte{})
	buf.WriteByte(ErrPacket)
	binary.Write(buf, binary.LittleEndian, errorCode)
	buf.WriteByte('#')
	buf.WriteString(sqlState)
	buf.WriteString(errMessage)
	return ch.writePacket(buf.Bytes())
}

// parseClientHandshakePacket 解析客户端握手数据包
func (ch *ConnHandle) parseClientHandshakePacket(firstTime bool, data []byte) (string, string, []byte, error) {
	pos := 0
	clientFlags, pos, ok := readUint32(data, pos)
	if !ok {
		return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取客户端标志")
	}
	if clientFlags&CapabilityClientProtocol41 == 0 {
		return "", "", nil, errors.New("parseClientHandshakePacket: 仅支持协议4.1")
	}
	if firstTime {
		ch.Capabilities = clientFlags & (CapabilityClientDeprecateEOF | CapabilityClientFoundRows)
	}
	if clientFlags&CapabilityClientMultiStatements > 0 {
		ch.Capabilities |= CapabilityClientMultiStatements
	}

	_, pos, ok = readUint32(data, pos)
	if !ok {
		return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取maxPacketSize")
	}

	characterSet, pos, ok := readByte(data, pos)
	if !ok {
		return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取characterSet")
	}
	ch.CharacterSet = characterSet

	pos += 23

	username, pos, ok := readNullString(data, pos)
	if !ok {
		return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取username")
	}

	var authResponse []byte
	// 只处理安全连接方式的身份验证响应
	if clientFlags&CapabilityClientSecureConnection != 0 {
		var l byte
		l, pos, ok = readByte(data, pos)
		if !ok {
			return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response长度")
		}
		authResponse, pos, ok = readBytesCopy(data, pos, int(l))
		if !ok {
			return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response")
		}
	} else {
		a := ""
		a, pos, ok = readNullString(data, pos)
		if !ok {
			return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response")
		}
		authResponse = []byte(a)
	}

	if clientFlags&CapabilityClientConnectWithDB != 0 {
		dbname := ""
		dbname, pos, ok = readNullString(data, pos)
		if !ok {
			return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取dbname")
		}
		ch.SchemaName = dbname
	}

	authMethod := MysqlNativePassword
	if clientFlags&CapabilityClientPluginAuth != 0 {
		authMethod, pos, ok = readNullString(data, pos)
		if !ok {
			return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取authMethod")
		}
	}
	if authMethod == "" {
		authMethod = MysqlNativePassword
	}

	if clientFlags&CapabilityClientConnAttr != 0 {
		if _, _, err := parseConnAttrs(data, pos); err != nil {
			Loger.Error("解码客户端发送的连接属性: ", err)
		}
	}

	return username, authMethod, authResponse, nil
}

// ReadOnePack 读取一个数据包
func (ch *ConnHandle) ReadOnePack() ([]byte, error) {
	var r io.Reader = ch.Conn.(io.Reader)

	length, err := ch.readHeaderFrom(r)
	if err != nil {
		return nil, err
	}
	if length < MaxPacketSize {
		buf := make([]byte, length)
		if _, err := io.ReadFull(r, buf); err != nil {
			return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")
		}
		return buf, nil
	}
	return nil, errors.New("readEphemeralPacketDirect doesn't support more than one packet")
}

// WriteOnePack 写入一个数据包
func (ch *ConnHandle) WriteOnePack() ([]byte, error) {
	salt, _ := NewSalt()
	p01 := Packet01{
		Ver:      10,
		VerSion:  ServerVersion,
		ServerId: <-tidchan,
		Salt:     salt,
		SerFlag1: []byte{255, 247},
		Bm:       28,
		SerType:  0,
		SerFlag2: []byte{15, 128},
	}
	err := ch.writePacket(p01.GetBytes())
	return salt, err
}

// writePacket 写入数据包
func (ch *ConnHandle) writePacket(data []byte) error {
	index := 0
	length := len(data)
	w := ch.Conn.(*net.TCPConn)
	for {
		packetLength := length
		if packetLength > MaxPacketSize {
			packetLength = MaxPacketSize
		}

		var header [4]byte
		header[0] = byte(packetLength)
		header[1] = byte(packetLength >> 8)
		header[2] = byte(packetLength >> 16)
		header[3] = ch.sequence
		if n, err := w.Write(header[:]); err != nil {
			return errors.New("Write(header) failed")
		} else if n != 4 {
			return errors.New("Write(header) returned a short write: < 4")
		}

		if n, err := w.Write(data[index : index+packetLength]); err != nil {
			return errors.New("Write(packet) failed")
		} else if n != packetLength {
			return errors.New("Write(packet) returned a short write")
		}

		ch.sequence++
		length -= packetLength
		if length == 0 {
			if packetLength == MaxPacketSize {
				header[0] = 0
				header[1] = 0
				header[2] = 0
				header[3] = ch.sequence
				if n, err := w.Write(header[:]); err != nil {
					return errors.New("Write(empty header) failed")
				} else if n != 4 {
					return errors.New("Write(empty header) returned a short write")
				}
				ch.sequence++
			}
			return nil
		}
		index += packetLength
	}
}

// readHeaderFrom 从读取器中读取头部
func (ch *ConnHandle) readHeaderFrom(r io.Reader) (int, error) {
	var header [4]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		if err == io.EOF {
			return 0, err
		}
		if strings.HasSuffix(err.Error(), "read: connection reset by peer") {
			return 0, io.EOF
		}
		return 0, errors.New("io.ReadFull(header size) failed")
	}

	sequence := uint8(header[3])
	if sequence != ch.sequence {
		return 0, errors.New("invalid sequence: expected " + strconv.Itoa(int(ch.sequence)) + ", got " + strconv.Itoa(int(sequence)))
	}
	ch.sequence++
	return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16), nil
}

// readEphemeralPacket 读取临时数据包
func (c *ConnHandle) readEphemeralPacket() ([]byte, error) {
	var r io.Reader = c.Conn.(io.Reader)
	length, err := c.readHeaderFrom(r)
	if err != nil {
		return nil, err
	}
	if length == 0 {
		return nil, nil
	}

	data := make([]byte, length)
	if _, err := io.ReadFull(r, data); err != nil {
		return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")
	}

	if length < MaxPacketSize {
		return data, nil
	}

	for {
		next, err := c.readOnePacket()
		if err != nil {
			return nil, err
		}
		if len(next) == 0 {
			break
		}
		data = append(data, next...)
		if len(next) < MaxPacketSize {
			break
		}
	}
	return data, nil
}

// readOnePacket 读取一个数据包
func (ch *ConnHandle) readOnePacket() ([]byte, error) {
	var r io.Reader = ch.Conn.(io.Reader)
	length, err := ch.readHeaderFrom(r)
	if err != nil {
		return nil, err
	}
	if length == 0 {
		return nil, nil
	}
	data := make([]byte, length)
	if _, err := io.ReadFull(r, data); err != nil {
		return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")
	}
	return data, nil
}

// parseConnAttrs 解析连接属性
func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
	var attrLen uint64
	attrLen, pos, ok := readLenEncInt(data, pos)
	if !ok {
		return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性变量长度")
	}

	var attrLenRead uint64
	attrs := make(map[string]string)

	for attrLenRead < attrLen {
		var keyLen byte
		keyLen, pos, ok = readByte(data, pos)
		if !ok {
			return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性键长度")
		}
		attrLenRead += uint64(keyLen) + 1

		var connAttrKey []byte
		connAttrKey, pos, ok = readBytesCopy(data, pos, int(keyLen))
		if !ok {
			return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性键")
		}

		var valLen byte
		valLen, pos, ok = readByte(data, pos)
		if !ok {
			return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性值长度")
		}
		attrLenRead += uint64(valLen) + 1

		var connAttrVal []byte
		connAttrVal, pos, ok = readBytesCopy(data, pos, int(valLen))
		if !ok {
			return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性值")
		}

		attrs[string(connAttrKey[:])] = string(connAttrVal[:])
	}

	return attrs, pos, nil
}

// isPassScrambleMysqlNativePassword 验证密码是否匹配
func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword string) bool {
	if len(reply) == 0 {
		return false
	}
	if mysqlNativePassword == "" {
		return false
	}
	mysqlNativePassword = NativePassword(mysqlNativePassword)
	if strings.Contains(mysqlNativePassword, "*") {
		mysqlNativePassword = mysqlNativePassword[1:]
	}

	hash, err := hex.DecodeString(mysqlNativePassword)
	if err != nil {
		return false
	}
	
	crypt := sha1.New()
	crypt.Write(salt)
	crypt.Write(hash)
	scramble := crypt.Sum(nil)

	for i := range scramble {
		scramble[i] ^= reply[i]
	}
	hashStage1 := scramble

	crypt.Reset()
	crypt.Write(hashStage1)
	candidateHash2 := crypt.Sum(nil)

	return bytes.Equal(candidateHash2, hash)
}

// NativePassword 生成原生密码格式
func NativePassword(password string) string {
	if len(password) == 0 {
		return ""
	}

	hash := sha1.New()
	hash.Write([]byte(password))
	s1 := hash.Sum(nil)

	hash.Reset()
	hash.Write(s1)
	s2 := hash.Sum(nil)

	s := strings.ToUpper(hex.EncodeToString(s2))

	return fmt.Sprintf("*%s", s)
}

// NewSalt 生成新的盐值
func NewSalt() ([]byte, error) {
	salt := make([]byte, 20)
	if _, err := rand.Read(salt); err != nil {
		return nil, err
	}

	for i := 0; i < len(salt); i++ {
		salt[i] &= 0x7f
		if salt[i] == '\x00' || salt[i] == '$' {
			salt[i]++
		}
	}

	return salt, nil
}

// Packet01 数据包结构

type Packet01 struct {
	Ver      byte
	VerSion  string
	ServerId uint32
	Salt     []byte
	SerFlag1 []byte
	Bm       int8
	SerType  int16
	SerFlag2 []byte
}

// GetBytes 获取数据包的字节表示
func (p01 *Packet01) GetBytes() []byte {
	b := bytes.NewBuffer([]byte{})
	b.WriteByte(p01.Ver)
	b.WriteString(p01.VerSion)
	b.WriteByte(byte(0))
	binary.Write(b, binary.BigEndian, p01.ServerId)
	b.Write(p01.Salt[:8])
	b.WriteByte(byte(0))
	binary.Write(b, binary.BigEndian, p01.SerFlag1)
	b.WriteByte(byte(p01.Bm))
	binary.Write(b, binary.BigEndian, p01.SerType)
	binary.Write(b, binary.BigEndian, p01.SerFlag2)
	b.WriteByte(byte(21))
	b.Write(bytes.Repeat([]byte{0}, 10))
	b.Write(p01.Salt[8:])
	b.WriteByte(byte(0))
	b.WriteString(MysqlNativePassword)
	b.WriteByte(byte(0))
	return b.Bytes()
}

// MysqlResponse MySQL响应结构
type MysqlResponse struct {
	Fs   []FieldProtocol
	Rows [][]string
}

// FieldProtocol 字段协议结构
type FieldProtocol struct {
	Catalog       string
	Database      string
	Table         string
	OriginalTable string
	Name          string
	OriginalName  string
	Charset       int
	Length        int
	Type          int
	Flags         int
	Decimals      int
}

// GetBytes 获取字段协议的字节表示
func (fp *FieldProtocol) GetBytes(pk int) []byte {
	b := bytes.NewBuffer([]byte{0, 0, 0, 0})
	lt := EncodeLength(len(fp.Catalog))
	b.Write(lt)
	b.WriteString(fp.Catalog)
	lt = EncodeLength(len(fp.Database))
	b.Write(lt)
	b.WriteString(fp.Database)
	lt = EncodeLength(len(fp.Table))
	b.Write(lt)
	b.WriteString(fp.Table)
	lt = EncodeLength(len(fp.OriginalTable))
	b.Write(lt)
	b.WriteString(fp.OriginalTable)
	lt = EncodeLength(len(fp.Name))
	b.Write(lt)
	b.WriteString(fp.Name)
	lt = EncodeLength(len(fp.OriginalName))
	b.Write(lt)
	b.WriteString(fp.OriginalName)
	b.WriteByte(0x0c)
	b.WriteByte(byte(fp.Charset))
	b.WriteByte(0x00)
	b.Write(intToBytesLittleEndian(fp.Length))
	b.WriteByte(byte(fp.Type))
	lt = intToBytesLittleEndian(fp.Flags)
	b.WriteByte(lt[0])
	b.WriteByte(lt[1])
	b.WriteByte(byte(fp.Decimals))
	b.WriteByte(0x00)
	b.WriteByte(0x00)
	bs := b.Bytes()
	lb := intToBytesLittleEndian(len(bs) - 4)
	bs[0] = byte(lb[0])
	bs[1] = byte(lb[1])
	bs[2] = byte(lb[2])
	bs[3] = byte(pk)
	return bs
}

// GetBytes 获取MySQL响应的字节表示
func (mr *MysqlResponse) GetBytes() []byte {
	pk := 1
	b := bytes.NewBuffer([]byte{})
	numberFields := EncodeLength(len(mr.Fs))
	numberPackets := intToBytesLittleEndian(len(numberFields))
	numberPackets[3] = byte(pk)
	b.Write(numberPackets)
	b.Write(numberFields)

	for _, f := range mr.Fs {
		pk++
		if pk > 255 {
			pk = 0
		}
		b.Write(f.GetBytes(pk))
	}
	pk++
	if pk > 255 {
		pk = 0
	}
	b.Write(getEOFPacket(pk))
	for _, r := range mr.Rows {
		pk++
		if pk > 255 {
			pk = 0
		}
		b.Write(getRowPacket(r, pk))
	}
	pk++
	if pk > 255 {
		pk = 0
	}
	b.Write(getEOFPacket(pk))
	return b.Bytes()
}

// getRowPacket 获取行数据包
func getRowPacket(strs []string, pk int) []byte {
	b := bytes.NewBuffer([]byte{0, 0, 0, 0})
	for _, s := range strs {
		lt := EncodeLength(len(s))
		b.Write(lt)
		b.WriteString(s)
	}
	bs := b.Bytes()
	lb := intToBytesLittleEndian(len(bs) - 4)
	bs[0] = byte(lb[0])
	bs[1] = byte(lb[1])
	bs[2] = byte(lb[2])
	bs[3] = byte(pk)
	return bs
}

// getEOFPacket 获取EOF数据包
func getEOFPacket(pk int) []byte {
	return []byte{0x05, 0x00, 0x00, byte(pk), 0xFE, 0x00, 0x00, 0x22, 0x00}
}

// EncodeLength 编码长度
func EncodeLength(length int) []byte {
	var buf bytes.Buffer

	if length < 251 {
		buf.WriteByte(byte(length))
	} else if length <= 65535 {
		buf.WriteByte(0xfd)
		binary.Write(&buf, binary.LittleEndian, uint16(length))
	} else {
		buf.WriteByte(0xfe)
		binary.Write(&buf, binary.LittleEndian, uint64(length))
	}

	return buf.Bytes()
}

// intToBytesLittleEndian 将整数转换为小端字节序
func intToBytesLittleEndian(num int) []byte {
	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint32(num))
	return buf.Bytes()
}

// readByte 读取一个字节
func readByte(data []byte, pos int) (byte, int, bool) {
	if pos >= len(data) {
		return 0, 0, false
	}
	return data[pos], pos + 1, true
}

// readBytesCopy 读取并复制字节
func readBytesCopy(data []byte, pos int, size int) ([]byte, int, bool) {
	if pos+size > len(data) {
		return nil, 0, false
	}
	result := make([]byte, size)
	copy(result, data[pos:pos+size])
	return result, pos + size, true
}

// readNullString 读取以null结尾的字符串
func readNullString(data []byte, pos int) (string, int, bool) {
	end := bytes.IndexByte(data[pos:], 0)
	if end == -1 {
		return "", 0, false
	}
	return string(data[pos : pos+end]), pos + end + 1, true
}

// readUint32 读取uint32
func readUint32(data []byte, pos int) (uint32, int, bool) {
	if pos+4 > len(data) {
		return 0, 0, false
	}
	return binary.LittleEndian.Uint32(data[pos : pos+4]), pos + 4, true
}

// readLenEncInt 读取长度编码的整数
func readLenEncInt(data []byte, pos int) (uint64, int, bool) {
	if pos >= len(data) {
		return 0, 0, false
	}
	switch data[pos] {
	case 0xfc:
		if pos+3 > len(data) {
			return 0, 0, false
		}
		return uint64(data[pos+1]) | uint64(data[pos+2])<<8, pos + 3, true
	case 0xfd:
		if pos+4 > len(data) {
			return 0, 0, false
		}
		return uint64(data[pos+1]) |
			uint64(data[pos+2])<<8 |
			uint64(data[pos+3])<<16, pos + 4, true
	case 0xfe:
		if pos+9 > len(data) {
			return 0, 0, false
		}
		return uint64(data[pos+1]) |
			uint64(data[pos+2])<<8 |
			uint64(data[pos+3])<<16 |
			uint64(data[pos+4])<<24 |
			uint64(data[pos+5])<<32 |
			uint64(data[pos+6])<<40 |
			uint64(data[pos+7])<<48 |
			uint64(data[pos+8])<<56, pos + 9, true
	}
	return uint64(data[pos]), pos + 1, true
}

写在最后:当程序员被逼急了...

这个项目让我深刻体会到:程序员的创造力往往来自于解决同事的"奇怪需求"。谁能想到,一个易语言的线程安全问题,最终会催生一个MySQL协议服务器?

所以,下次如果你的同事提出一个看似不合理的需求,不妨换个角度想想——也许这正是你提升技能、探索未知领域的好机会!

最后,如果你也有类似的"被逼无奈"的开发经历,欢迎在评论区分享,让我们一起吐槽,不对,让我们一起学习!


P.S. 同事的易语言项目现在运行得很稳定,他对我说:"还有很多地方可以这样优化,帮我再写几个接口吧!" 我沉默不语......

其他部分文章列表(注意:来源于本人csdn原创)