一个程序员的“被迫营业”故事
序:同事的困境
嘿,各位程序员大佬们!今天我要给你们讲一个既荒诞又真实的故事——关于我如何被同事的易语言代码“逼”成了“MySQL协议专家”,注意引号,不是真专家。
事情是这样的,我有个同事,有一个程序是用易语言写的,古董级的。最近要做升级,要缓存一个4G左右的数据,可是易语言是32位程序,做不到,于是我就用go开发了一个HTTP接口,数据我缓存,他用易语言调用这个接口,结果遇到了一个世纪难题:易语言使用http读文件不是线程不安全!是的,你没听错,在2025年,还有人在为易语言的线程安全问题挠头。
灵感乍现:“既然易语言搞不定HTTP,那我们为什么不用MySQL协议呢?”
那天,他调试来调试去,线程安全问题就是解决不了!最后,我盯着他足足看了三十秒钟,突然脑子里闪过一道光——对啊!MySQL协议是成熟稳定的数据库协议,几乎所有语言都有完善的驱动支持,包括易语言!
于是,一个大胆的想法诞生了:写一个简单的MySQL协议服务端,将HTTP接口伪装成MySQL查询。这样,同事就可以用易语言通过MySQL驱动轻松调用,完美避开线程安全的坑!
开发历程:从0到1的MySQL协议实现
说干就干!我用Go语言(感谢Go的并发模型!)开始了这个"伪MySQL服务器"的开发。过程中遇到了不少挑战:
- 协议解析:MySQL协议虽然公开,但细节繁多,特别是握手认证部分
- 连接处理:要支持多客户端同时连接,Go的goroutine正好派上用场
- 兼容性处理:不同语言的MySQL驱动初始化时会发送各种查询,必须一一处理
功能特性:麻雀虽小,五脏俱全
经过一番努力,这个"不正经"的MySQL服务器已经具备了以下功能:
- ✅ MySQL协议兼容,可以被各种语言的MySQL驱动正常连接
- ✅ 支持基本的数据库操作命令
- ✅ 处理客户端初始化查询(比如字符集、校对规则等)
- ✅ 自定义SQL命令支持:这里用了两个示例接口,没有真实的业务代码
- ✅ 数据库和表的模拟(通过文件系统目录结构)
自定义命令使用示例
GET MB 命令 - 字符串回显功能
mysql> get mb hello world;
+--------+-------------+
| key | value |
+--------+-------------+
| result | hello world |
+--------+-------------+
1 row in set
2. GET MD5 命令 - MD5计算功能
mysql> get md5 test123;
+--------+----------------------------------+
| key | value |
+--------+----------------------------------+
| result | 22b75d6007e06f4a959d1b1d69b4c4bd |
+--------+----------------------------------+
1 row in set
核心代码解析
1. 自定义命令处理机制
自定义命令是这个服务的核心功能之一,下面我们来看看它是如何实现的:
// 处理自定义SQL命令
if strings.HasPrefix(strings.ToUpper(sql), "GET MB") {
parts := strings.Fields(strings.ToUpper(sql))
var input string
if len(parts) > 2 {
input = strings.Join(parts[2:], " ") // 提取命令参数
} else {
input = "default" // 默认值处理
}
// 返回输入的字符串
mbResult := input
// 构建MySQL响应结构
rows := [][]string{
{"result", mbResult},
}
mr := MysqlResponse{
// 定义字段结构
Fs: []FieldProtocol{...},
Rows: rows,
}
c.Conn.Write(mr.GetBytes()) // 发送响应
} else if strings.HasPrefix(strings.ToUpper(sql), "GET MD5") {
// MD5命令处理逻辑...
} else {
// 处理标准SQL命令...
}
这段代码展示了自定义命令的处理流程:
- 命令识别:通过字符串前缀判断命令类型
- 参数解析:提取命令后面的参数部分
- 业务处理:执行相应的业务逻辑(字符串回显或MD5计算)
- 响应构建:创建标准的MySQL协议响应结构
- 数据发送:将响应序列化为二进制数据并发送
2. MySQL协议响应结构
服务通过MysqlResponse结构体来构建符合MySQL协议的响应数据:
// MysqlResponse MySQL响应结构
type MysqlResponse struct {
Fs []FieldProtocol // 字段定义
Rows [][]string // 数据行
}
// GetBytes 将响应转换为二进制数据
func (mr *MysqlResponse) GetBytes() []byte {
// 构建响应头部
var buf bytes.Buffer
buf.WriteByte(0x01) // 包头
// 写入字段定义
for _, field := range mr.Fs {
buf.Write(field.GetBytes())
}
// 写入EOF包
buf.Write(getEOFPacket())
// 写入数据行
for _, row := range mr.Rows {
buf.Write(getRowPacket(row))
}
// 写入最终的EOF包
buf.Write(getEOFPacket())
return buf.Bytes()
}
这个结构严格遵循MySQL的文本协议格式,包含:
- 字段定义:描述返回数据的结构
- 数据行:实际的查询结果
- 特殊标记:如EOF包,用于分隔不同的协议阶段
3. 连接处理与认证流程
服务使用Go的goroutine为每个客户端连接创建独立的处理线程:
// Start 启动连接处理
func (c *ConnHandle) Start() {
defer func() {
c.Conn.Close()
wg.Done()
}()
// 发送握手包
c.WriteOnePack()
// 读取并解析客户端响应
data, err := c.ReadOnePack()
if err != nil {
return
}
// 处理认证
username, password, dbname, err := parseClientHandshakePacket(data)
if err != nil || !isPassScrambleMysqlNativePassword(password, c.Salt) {
c.writeErrorPacket(1045, "28000", "用户名或密码错误")
return
}
// 发送认证成功响应
c.writeOKPacket()
// 主命令处理循环
for !exit {
err = c.handleNextCommand()
if err != nil {
break
}
}
}
这个连接处理流程包括:
- 握手初始化:服务器发送握手包给客户端
- 身份验证:验证用户名密码(本实现中简化了验证逻辑)
- 命令循环:持续接收并处理客户端命令
- 资源清理:连接结束时释放资源
4. 配置文件读取
服务通过readConfig函数读取配置文件:
// readConfig 读取配置文件
func readConfig(configPath string) (string, string, error) {
pz, err := config.ReadDefault(configPath)
if err != nil {
Loger.Error(err)
return err
}
if dbRoot, err = pz.String("mysqld", "datadir"); err != nil {
Loger.Error(err)
return err
}
return nil
}
这个函数负责从配置文件中读取服务配置,目前只有数据目录路径
配置文件如下: my.ini
[mysqld]
datadir=./data
程序源码如下:
package main
import (
"bytes"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"encoding/hex"
"errors"
"flag"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"gitcode.com/jjgtmgx/mgxlog"
"github.com/larspensjo/config"
)
// 最大数据包大小
const (
MaxPacketSize = (1 << 24) - 1 // 服务器支持的最大数据包大小
ProtocolVersion = 10 // MySQL协议版本,固定为10
)
// 认证方式常量
const (
// MysqlNativePassword 身份验证方式
MysqlNativePassword = "mysql_native_password"
)
// 能力标志常量
const (
// CapabilityClientFoundRows 返回找到的行数而不是受影响的行数
CapabilityClientFoundRows = 1 << 1
// CapabilityClientConnectWithDB 可以在连接时指定数据库
CapabilityClientConnectWithDB = 1 << 3
// CapabilityClientProtocol41 新的4.1协议,必须支持
CapabilityClientProtocol41 = 1 << 9
// CapabilityClientSecureConnection 新的4.1身份验证方式
CapabilityClientSecureConnection = 1 << 15
// CapabilityClientMultiStatements 支持在COM_QUERY和COM_STMT_PREPARE中处理多个语句
CapabilityClientMultiStatements = 1 << 16
// CapabilityClientPluginAuth 客户端支持插件身份验证
CapabilityClientPluginAuth = 1 << 19
// CapabilityClientConnAttr 允许在Protocol::HandshakeResponse41中使用连接属性
CapabilityClientConnAttr = 1 << 20
// CapabilityClientDeprecateEOF 期望在文本结果集的行之后使用OK(而不是EOF)
CapabilityClientDeprecateEOF = 1 << 24
)
// 数据包类型常量
const (
// ComQuit 客户端请求关闭连接
ComQuit = 0x01
// ComInitDB 客户端请求切换数据库
ComInitDB = 0x02
// ComQuery 客户端发送SQL查询
ComQuery = 0x03
// ComPing 客户端发送ping请求
ComPing = 0x0e
// ComSetOption 客户端设置选项
ComSetOption = 0x1b
// OKPacket OK数据包的头部标识
OKPacket = 0x00
// EOFPacket EOF数据包的头部标识
EOFPacket = 0xfe
// ErrPacket 错误数据包的头部标识
ErrPacket = 0xff
)
// 日志记录
var Loger, _ = mgxlog.NewMgxLog("runlog/", 10*1024*1024, 100, 3, 1000)
var exit bool = false
var wg sync.WaitGroup
var tidchan = make(chan uint32)
var dbRoot string
func main() {
defer Loger.Flush()
addr := flag.String("addr", ":3307", "http service address")
config := flag.String("config", "./my.ini", "configuration file path")
flag.Parse()
// 读取 my.ini 获取 datadir
if err := readConfig(*config); err != nil {
return
}
go CreateTid()
wg.Add(1)
go StartServer(*addr)
for {
var cmd string
fmt.Scanf("%s", &cmd)
if cmd == "exit" {
exit = true
break
}
fmt.Println("未知命令")
fmt.Println("exit 退出程序")
}
wg.Wait()
}
// readConfig 读取配置文件
func readConfig(configPath string) error {
pz, err := config.ReadDefault(configPath)
if err != nil {
Loger.Error(err)
return err
}
if dbRoot, err = pz.String("mysqld", "datadir"); err != nil {
Loger.Error(err)
return err
}
return nil
}
// StartServer 启动服务器
func StartServer(addr string) {
defer wg.Done()
var netListen net.Listener
for !exit {
if netListen == nil {
var err error
if netListen, err = net.Listen("tcp", addr); err != nil {
Loger.Error(err)
} else {
go func() {
for !exit {
conn, err := netListen.Accept()
if err != nil {
continue
}
ch := ConnHandle{Conn: conn}
go ch.Start()
}
}()
}
} else {
time.Sleep(2 * time.Second)
}
}
if netListen != nil {
netListen.Close()
}
}
// CreateTid 生成事务ID
func CreateTid() {
tid := uint32(1)
for {
tidchan <- tid
tid++
if tid > 999999999 {
tid = 1
}
}
}
var ServerVersion = "5.5.15"
type ConnHandle struct {
Conn net.Conn
sequence uint8
Capabilities uint32
SchemaName string
CharacterSet uint8
User string
}
// Start 启动连接处理
func (ch *ConnHandle) Start() {
defer ch.Conn.Close()
salt, err := ch.WriteOnePack()
if err != nil {
Loger.Error(ch.Conn.RemoteAddr().String(), err)
return
}
b, err := ch.ReadOnePack()
if err != nil {
Loger.Error(ch.Conn.RemoteAddr().String(), err)
return
}
user, _, authResponse, err := ch.parseClientHandshakePacket(true, b)
if err != nil {
Loger.Errorf("无法解析来自 %s 的客户端握手响应: %v", ch.Conn, err)
return
}
if !isPassScrambleMysqlNativePassword(authResponse, salt, "root") {
ch.writeErrorPacket(1045, "00001", "用户名或密码错误")
return
}
ch.User = user
if err := ch.writeOKPacket(0, 0, 0, 0); err != nil {
Loger.Errorf("无法向 %s 写入OK包: %v", ch.Conn, err)
return
}
for !exit {
err := ch.handleNextCommand()
if err != nil {
Loger.Error(err)
return
}
}
}
// handleNextCommand 处理下一个命令
func (c *ConnHandle) handleNextCommand() error {
c.sequence = 0
data, err := c.readEphemeralPacket()
if err != nil {
return err
}
switch data[0] {
case ComQuit:
return errors.New("ComQuit")
case ComInitDB:
dbName := string(data[1:])
if _, err := os.Stat(filepath.Join(dbRoot, dbName)); os.IsNotExist(err) {
c.writeErrorPacket(1049, "42000", fmt.Sprintf("未知数据库 '%s'", dbName))
return nil
}
c.SchemaName = dbName
c.writeOKPacket(0, 0, 0, 0)
case ComQuery:
sql := string(data[1:])
sql = collapseSpaces(sql)
sql = strings.TrimRight(sql, ";")
sql = strings.TrimSpace(sql)
sql = strings.Join(strings.Fields(sql), " ")
// 处理自定义SQL命令: get mb xxx
if strings.HasPrefix(strings.ToUpper(sql), "GET MB") {
parts := strings.Fields(strings.ToUpper(sql))
var input string
if len(parts) > 2 {
input = strings.Join(parts[2:], " ")
} else {
input = "default"
}
// 返回输入的字符串
mbResult := input
rows := [][]string{
{"result", mbResult},
}
mr := MysqlResponse{
Fs: []FieldProtocol{
{
Catalog: "def",
Database: "",
Table: "",
OriginalTable: "",
Name: "key",
OriginalName: "key",
Charset: 33,
Length: 50,
Type: 253,
Flags: 1,
Decimals: 0,
},
{
Catalog: "def",
Database: "",
Table: "",
OriginalTable: "",
Name: "value",
OriginalName: "value",
Charset: 33,
Length: 50,
Type: 253,
Flags: 1,
Decimals: 0,
},
},
Rows: rows,
}
c.Conn.Write(mr.GetBytes())
} else if strings.HasPrefix(strings.ToUpper(sql), "GET MD5") {
parts := strings.Fields(strings.ToUpper(sql))
var input string
if len(parts) > 2 {
input = strings.Join(parts[2:], " ")
} else {
input = "default"
}
// 计算MD5值
hash := md5.Sum([]byte(input))
md5Result := hex.EncodeToString(hash[:])
rows := [][]string{
{"result", md5Result},
}
mr := MysqlResponse{
Fs: []FieldProtocol{
{
Catalog: "def",
Database: "",
Table: "",
OriginalTable: "",
Name: "key",
OriginalName: "key",
Charset: 33,
Length: 50,
Type: 253,
Flags: 1,
Decimals: 0,
},
{
Catalog: "def",
Database: "",
Table: "",
OriginalTable: "",
Name: "value",
OriginalName: "value",
Charset: 33,
Length: 50,
Type: 253,
Flags: 1,
Decimals: 0,
},
},
Rows: rows,
}
c.Conn.Write(mr.GetBytes())
} else if strings.HasPrefix(strings.ToUpper(sql), "SELECT COLLATIONS") {
rows := [][]string{}
for i := 0; i < 1000; i++ {
rows = append(rows, []string{strconv.Itoa(i)})
}
mr := MysqlResponse{
Fs: []FieldProtocol{
{
Catalog: "def",
Database: "information_schema",
Table: "COLLATIONS",
OriginalTable: "COLLATIONS",
Name: "result",
OriginalName: "result",
Charset: 33,
Length: 96,
Type: 253,
Flags: 1,
Decimals: 0,
},
},
Rows: rows,
}
c.Conn.Write(mr.GetBytes())
} else if strings.HasPrefix(strings.ToUpper(sql), "SHOW DATABASES") {
dbs, err := os.ReadDir(dbRoot)
if err != nil {
c.writeErrorPacket(1049, "HY000", "无法读取数据库目录")
return nil
}
var rows [][]string
for _, entry := range dbs {
if entry.IsDir() {
rows = append(rows, []string{entry.Name()})
}
}
mr := MysqlResponse{
Fs: []FieldProtocol{
{
Catalog: "def",
Database: "information_schema",
Table: "SCHEMATA",
OriginalTable: "SCHEMATA",
Name: "Database",
OriginalName: "SCHEMA_NAME",
Charset: 33,
Length: 192,
Type: 253,
Flags: 1,
Decimals: 0,
},
},
Rows: rows,
}
c.Conn.Write(mr.GetBytes())
} else if strings.HasPrefix(sql, "SELECT @@character_set_database, @@collation_database") {
rows := [][]string{
{"utf8mb4", "utf8mb4_general_ci"},
}
mr := MysqlResponse{
Fs: []FieldProtocol{
{
Catalog: "def",
Database: "information_schema",
Table: "SCHEMATA",
OriginalTable: "SCHEMATA",
Name: "@@character_set_database",
OriginalName: "@@character_set_database",
Charset: 33,
Length: 192,
Type: 253,
Flags: 1,
Decimals: 0,
},
{
Catalog: "def",
Database: "information_schema",
Table: "SCHEMATA",
OriginalTable: "SCHEMATA",
Name: "@@collation_database",
OriginalName: "@@collation_database",
Charset: 33,
Length: 192,
Type: 253,
Flags: 1,
Decimals: 0,
},
},
Rows: rows,
}
c.Conn.Write(mr.GetBytes())
} else if strings.HasPrefix(sql, "SHOW FULL TABLES") {
if c.SchemaName == "" {
c.writeErrorPacket(1046, "3D000", "未选择数据库")
return nil
}
dbPath := filepath.Join(dbRoot, c.SchemaName)
entries, err := os.ReadDir(dbPath)
if err != nil {
c.writeErrorPacket(1049, "42000", fmt.Sprintf("Unknown database '%s'", c.SchemaName))
return nil
}
var rows [][]string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".mgx") {
tableName := strings.TrimSuffix(entry.Name(), ".mgx")
rows = append(rows, []string{tableName, "BASE TABLE"})
}
}
mr := MysqlResponse{
Fs: []FieldProtocol{
{
Catalog: "def",
Database: "information_schema",
Table: "SCHEMATA",
OriginalTable: "SCHEMATA",
Name: "Tables_in_" + c.SchemaName,
OriginalName: "Tables_in_" + c.SchemaName,
Charset: 33,
Length: 192,
Type: 253,
Flags: 1,
Decimals: 0,
},
{
Catalog: "def",
Database: "information_schema",
Table: "SCHEMATA",
OriginalTable: "SCHEMATA",
Name: "Table_type",
OriginalName: "Table_type",
Charset: 33,
Length: 192,
Type: 253,
Flags: 1,
Decimals: 0,
},
},
Rows: rows,
}
c.Conn.Write(mr.GetBytes())
} else {
c.writeOKPacket(0, 0, 0, 0)
}
case ComPing:
c.writeOKPacket(0, 0, 0, 0)
case ComSetOption:
default:
}
return nil
}
// collapseSpaces 压缩字符串中的空格
func collapseSpaces(input string) string {
var buf bytes.Buffer
reader := strings.NewReader(input)
prevWasSpace := false
for {
r, _, err := reader.ReadRune()
if err != nil {
break
}
if r == ' ' {
if !prevWasSpace {
buf.WriteRune(r)
}
prevWasSpace = true
} else {
buf.WriteRune(r)
prevWasSpace = false
}
}
return buf.String()
}
// writeOKPacket 写入OK数据包
func (ch *ConnHandle) writeOKPacket(affectedRows, lastInsertID uint64, flags uint16, warnings uint16) error {
buf := bytes.NewBuffer([]byte{})
buf.WriteByte(OKPacket)
binary.Write(buf, binary.LittleEndian, affectedRows)
binary.Write(buf, binary.LittleEndian, lastInsertID)
binary.Write(buf, binary.LittleEndian, flags)
binary.Write(buf, binary.LittleEndian, warnings)
return ch.writePacket(buf.Bytes())
}
// writeErrorPacket 写入错误数据包
func (ch *ConnHandle) writeErrorPacket(errorCode uint16, sqlState, errMessage string) error {
buf := bytes.NewBuffer([]byte{})
buf.WriteByte(ErrPacket)
binary.Write(buf, binary.LittleEndian, errorCode)
buf.WriteByte('#')
buf.WriteString(sqlState)
buf.WriteString(errMessage)
return ch.writePacket(buf.Bytes())
}
// parseClientHandshakePacket 解析客户端握手数据包
func (ch *ConnHandle) parseClientHandshakePacket(firstTime bool, data []byte) (string, string, []byte, error) {
pos := 0
clientFlags, pos, ok := readUint32(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取客户端标志")
}
if clientFlags&CapabilityClientProtocol41 == 0 {
return "", "", nil, errors.New("parseClientHandshakePacket: 仅支持协议4.1")
}
if firstTime {
ch.Capabilities = clientFlags & (CapabilityClientDeprecateEOF | CapabilityClientFoundRows)
}
if clientFlags&CapabilityClientMultiStatements > 0 {
ch.Capabilities |= CapabilityClientMultiStatements
}
_, pos, ok = readUint32(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取maxPacketSize")
}
characterSet, pos, ok := readByte(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取characterSet")
}
ch.CharacterSet = characterSet
pos += 23
username, pos, ok := readNullString(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取username")
}
var authResponse []byte
// 只处理安全连接方式的身份验证响应
if clientFlags&CapabilityClientSecureConnection != 0 {
var l byte
l, pos, ok = readByte(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response长度")
}
authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response")
}
} else {
a := ""
a, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取auth-response")
}
authResponse = []byte(a)
}
if clientFlags&CapabilityClientConnectWithDB != 0 {
dbname := ""
dbname, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取dbname")
}
ch.SchemaName = dbname
}
authMethod := MysqlNativePassword
if clientFlags&CapabilityClientPluginAuth != 0 {
authMethod, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, errors.New("parseClientHandshakePacket: 无法读取authMethod")
}
}
if authMethod == "" {
authMethod = MysqlNativePassword
}
if clientFlags&CapabilityClientConnAttr != 0 {
if _, _, err := parseConnAttrs(data, pos); err != nil {
Loger.Error("解码客户端发送的连接属性: ", err)
}
}
return username, authMethod, authResponse, nil
}
// ReadOnePack 读取一个数据包
func (ch *ConnHandle) ReadOnePack() ([]byte, error) {
var r io.Reader = ch.Conn.(io.Reader)
length, err := ch.readHeaderFrom(r)
if err != nil {
return nil, err
}
if length < MaxPacketSize {
buf := make([]byte, length)
if _, err := io.ReadFull(r, buf); err != nil {
return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")
}
return buf, nil
}
return nil, errors.New("readEphemeralPacketDirect doesn't support more than one packet")
}
// WriteOnePack 写入一个数据包
func (ch *ConnHandle) WriteOnePack() ([]byte, error) {
salt, _ := NewSalt()
p01 := Packet01{
Ver: 10,
VerSion: ServerVersion,
ServerId: <-tidchan,
Salt: salt,
SerFlag1: []byte{255, 247},
Bm: 28,
SerType: 0,
SerFlag2: []byte{15, 128},
}
err := ch.writePacket(p01.GetBytes())
return salt, err
}
// writePacket 写入数据包
func (ch *ConnHandle) writePacket(data []byte) error {
index := 0
length := len(data)
w := ch.Conn.(*net.TCPConn)
for {
packetLength := length
if packetLength > MaxPacketSize {
packetLength = MaxPacketSize
}
var header [4]byte
header[0] = byte(packetLength)
header[1] = byte(packetLength >> 8)
header[2] = byte(packetLength >> 16)
header[3] = ch.sequence
if n, err := w.Write(header[:]); err != nil {
return errors.New("Write(header) failed")
} else if n != 4 {
return errors.New("Write(header) returned a short write: < 4")
}
if n, err := w.Write(data[index : index+packetLength]); err != nil {
return errors.New("Write(packet) failed")
} else if n != packetLength {
return errors.New("Write(packet) returned a short write")
}
ch.sequence++
length -= packetLength
if length == 0 {
if packetLength == MaxPacketSize {
header[0] = 0
header[1] = 0
header[2] = 0
header[3] = ch.sequence
if n, err := w.Write(header[:]); err != nil {
return errors.New("Write(empty header) failed")
} else if n != 4 {
return errors.New("Write(empty header) returned a short write")
}
ch.sequence++
}
return nil
}
index += packetLength
}
}
// readHeaderFrom 从读取器中读取头部
func (ch *ConnHandle) readHeaderFrom(r io.Reader) (int, error) {
var header [4]byte
if _, err := io.ReadFull(r, header[:]); err != nil {
if err == io.EOF {
return 0, err
}
if strings.HasSuffix(err.Error(), "read: connection reset by peer") {
return 0, io.EOF
}
return 0, errors.New("io.ReadFull(header size) failed")
}
sequence := uint8(header[3])
if sequence != ch.sequence {
return 0, errors.New("invalid sequence: expected " + strconv.Itoa(int(ch.sequence)) + ", got " + strconv.Itoa(int(sequence)))
}
ch.sequence++
return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16), nil
}
// readEphemeralPacket 读取临时数据包
func (c *ConnHandle) readEphemeralPacket() ([]byte, error) {
var r io.Reader = c.Conn.(io.Reader)
length, err := c.readHeaderFrom(r)
if err != nil {
return nil, err
}
if length == 0 {
return nil, nil
}
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")
}
if length < MaxPacketSize {
return data, nil
}
for {
next, err := c.readOnePacket()
if err != nil {
return nil, err
}
if len(next) == 0 {
break
}
data = append(data, next...)
if len(next) < MaxPacketSize {
break
}
}
return data, nil
}
// readOnePacket 读取一个数据包
func (ch *ConnHandle) readOnePacket() ([]byte, error) {
var r io.Reader = ch.Conn.(io.Reader)
length, err := ch.readHeaderFrom(r)
if err != nil {
return nil, err
}
if length == 0 {
return nil, nil
}
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return nil, errors.New("io.ReadFull(packet body of length " + strconv.Itoa(length) + ") failed")
}
return data, nil
}
// parseConnAttrs 解析连接属性
func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
var attrLen uint64
attrLen, pos, ok := readLenEncInt(data, pos)
if !ok {
return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性变量长度")
}
var attrLenRead uint64
attrs := make(map[string]string)
for attrLenRead < attrLen {
var keyLen byte
keyLen, pos, ok = readByte(data, pos)
if !ok {
return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性键长度")
}
attrLenRead += uint64(keyLen) + 1
var connAttrKey []byte
connAttrKey, pos, ok = readBytesCopy(data, pos, int(keyLen))
if !ok {
return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性键")
}
var valLen byte
valLen, pos, ok = readByte(data, pos)
if !ok {
return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性值长度")
}
attrLenRead += uint64(valLen) + 1
var connAttrVal []byte
connAttrVal, pos, ok = readBytesCopy(data, pos, int(valLen))
if !ok {
return nil, 0, errors.New("parseClientHandshakePacket: 无法读取连接属性值")
}
attrs[string(connAttrKey[:])] = string(connAttrVal[:])
}
return attrs, pos, nil
}
// isPassScrambleMysqlNativePassword 验证密码是否匹配
func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword string) bool {
if len(reply) == 0 {
return false
}
if mysqlNativePassword == "" {
return false
}
mysqlNativePassword = NativePassword(mysqlNativePassword)
if strings.Contains(mysqlNativePassword, "*") {
mysqlNativePassword = mysqlNativePassword[1:]
}
hash, err := hex.DecodeString(mysqlNativePassword)
if err != nil {
return false
}
crypt := sha1.New()
crypt.Write(salt)
crypt.Write(hash)
scramble := crypt.Sum(nil)
for i := range scramble {
scramble[i] ^= reply[i]
}
hashStage1 := scramble
crypt.Reset()
crypt.Write(hashStage1)
candidateHash2 := crypt.Sum(nil)
return bytes.Equal(candidateHash2, hash)
}
// NativePassword 生成原生密码格式
func NativePassword(password string) string {
if len(password) == 0 {
return ""
}
hash := sha1.New()
hash.Write([]byte(password))
s1 := hash.Sum(nil)
hash.Reset()
hash.Write(s1)
s2 := hash.Sum(nil)
s := strings.ToUpper(hex.EncodeToString(s2))
return fmt.Sprintf("*%s", s)
}
// NewSalt 生成新的盐值
func NewSalt() ([]byte, error) {
salt := make([]byte, 20)
if _, err := rand.Read(salt); err != nil {
return nil, err
}
for i := 0; i < len(salt); i++ {
salt[i] &= 0x7f
if salt[i] == '\x00' || salt[i] == '$' {
salt[i]++
}
}
return salt, nil
}
// Packet01 数据包结构
type Packet01 struct {
Ver byte
VerSion string
ServerId uint32
Salt []byte
SerFlag1 []byte
Bm int8
SerType int16
SerFlag2 []byte
}
// GetBytes 获取数据包的字节表示
func (p01 *Packet01) GetBytes() []byte {
b := bytes.NewBuffer([]byte{})
b.WriteByte(p01.Ver)
b.WriteString(p01.VerSion)
b.WriteByte(byte(0))
binary.Write(b, binary.BigEndian, p01.ServerId)
b.Write(p01.Salt[:8])
b.WriteByte(byte(0))
binary.Write(b, binary.BigEndian, p01.SerFlag1)
b.WriteByte(byte(p01.Bm))
binary.Write(b, binary.BigEndian, p01.SerType)
binary.Write(b, binary.BigEndian, p01.SerFlag2)
b.WriteByte(byte(21))
b.Write(bytes.Repeat([]byte{0}, 10))
b.Write(p01.Salt[8:])
b.WriteByte(byte(0))
b.WriteString(MysqlNativePassword)
b.WriteByte(byte(0))
return b.Bytes()
}
// MysqlResponse MySQL响应结构
type MysqlResponse struct {
Fs []FieldProtocol
Rows [][]string
}
// FieldProtocol 字段协议结构
type FieldProtocol struct {
Catalog string
Database string
Table string
OriginalTable string
Name string
OriginalName string
Charset int
Length int
Type int
Flags int
Decimals int
}
// GetBytes 获取字段协议的字节表示
func (fp *FieldProtocol) GetBytes(pk int) []byte {
b := bytes.NewBuffer([]byte{0, 0, 0, 0})
lt := EncodeLength(len(fp.Catalog))
b.Write(lt)
b.WriteString(fp.Catalog)
lt = EncodeLength(len(fp.Database))
b.Write(lt)
b.WriteString(fp.Database)
lt = EncodeLength(len(fp.Table))
b.Write(lt)
b.WriteString(fp.Table)
lt = EncodeLength(len(fp.OriginalTable))
b.Write(lt)
b.WriteString(fp.OriginalTable)
lt = EncodeLength(len(fp.Name))
b.Write(lt)
b.WriteString(fp.Name)
lt = EncodeLength(len(fp.OriginalName))
b.Write(lt)
b.WriteString(fp.OriginalName)
b.WriteByte(0x0c)
b.WriteByte(byte(fp.Charset))
b.WriteByte(0x00)
b.Write(intToBytesLittleEndian(fp.Length))
b.WriteByte(byte(fp.Type))
lt = intToBytesLittleEndian(fp.Flags)
b.WriteByte(lt[0])
b.WriteByte(lt[1])
b.WriteByte(byte(fp.Decimals))
b.WriteByte(0x00)
b.WriteByte(0x00)
bs := b.Bytes()
lb := intToBytesLittleEndian(len(bs) - 4)
bs[0] = byte(lb[0])
bs[1] = byte(lb[1])
bs[2] = byte(lb[2])
bs[3] = byte(pk)
return bs
}
// GetBytes 获取MySQL响应的字节表示
func (mr *MysqlResponse) GetBytes() []byte {
pk := 1
b := bytes.NewBuffer([]byte{})
numberFields := EncodeLength(len(mr.Fs))
numberPackets := intToBytesLittleEndian(len(numberFields))
numberPackets[3] = byte(pk)
b.Write(numberPackets)
b.Write(numberFields)
for _, f := range mr.Fs {
pk++
if pk > 255 {
pk = 0
}
b.Write(f.GetBytes(pk))
}
pk++
if pk > 255 {
pk = 0
}
b.Write(getEOFPacket(pk))
for _, r := range mr.Rows {
pk++
if pk > 255 {
pk = 0
}
b.Write(getRowPacket(r, pk))
}
pk++
if pk > 255 {
pk = 0
}
b.Write(getEOFPacket(pk))
return b.Bytes()
}
// getRowPacket 获取行数据包
func getRowPacket(strs []string, pk int) []byte {
b := bytes.NewBuffer([]byte{0, 0, 0, 0})
for _, s := range strs {
lt := EncodeLength(len(s))
b.Write(lt)
b.WriteString(s)
}
bs := b.Bytes()
lb := intToBytesLittleEndian(len(bs) - 4)
bs[0] = byte(lb[0])
bs[1] = byte(lb[1])
bs[2] = byte(lb[2])
bs[3] = byte(pk)
return bs
}
// getEOFPacket 获取EOF数据包
func getEOFPacket(pk int) []byte {
return []byte{0x05, 0x00, 0x00, byte(pk), 0xFE, 0x00, 0x00, 0x22, 0x00}
}
// EncodeLength 编码长度
func EncodeLength(length int) []byte {
var buf bytes.Buffer
if length < 251 {
buf.WriteByte(byte(length))
} else if length <= 65535 {
buf.WriteByte(0xfd)
binary.Write(&buf, binary.LittleEndian, uint16(length))
} else {
buf.WriteByte(0xfe)
binary.Write(&buf, binary.LittleEndian, uint64(length))
}
return buf.Bytes()
}
// intToBytesLittleEndian 将整数转换为小端字节序
func intToBytesLittleEndian(num int) []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint32(num))
return buf.Bytes()
}
// readByte 读取一个字节
func readByte(data []byte, pos int) (byte, int, bool) {
if pos >= len(data) {
return 0, 0, false
}
return data[pos], pos + 1, true
}
// readBytesCopy 读取并复制字节
func readBytesCopy(data []byte, pos int, size int) ([]byte, int, bool) {
if pos+size > len(data) {
return nil, 0, false
}
result := make([]byte, size)
copy(result, data[pos:pos+size])
return result, pos + size, true
}
// readNullString 读取以null结尾的字符串
func readNullString(data []byte, pos int) (string, int, bool) {
end := bytes.IndexByte(data[pos:], 0)
if end == -1 {
return "", 0, false
}
return string(data[pos : pos+end]), pos + end + 1, true
}
// readUint32 读取uint32
func readUint32(data []byte, pos int) (uint32, int, bool) {
if pos+4 > len(data) {
return 0, 0, false
}
return binary.LittleEndian.Uint32(data[pos : pos+4]), pos + 4, true
}
// readLenEncInt 读取长度编码的整数
func readLenEncInt(data []byte, pos int) (uint64, int, bool) {
if pos >= len(data) {
return 0, 0, false
}
switch data[pos] {
case 0xfc:
if pos+3 > len(data) {
return 0, 0, false
}
return uint64(data[pos+1]) | uint64(data[pos+2])<<8, pos + 3, true
case 0xfd:
if pos+4 > len(data) {
return 0, 0, false
}
return uint64(data[pos+1]) |
uint64(data[pos+2])<<8 |
uint64(data[pos+3])<<16, pos + 4, true
case 0xfe:
if pos+9 > len(data) {
return 0, 0, false
}
return uint64(data[pos+1]) |
uint64(data[pos+2])<<8 |
uint64(data[pos+3])<<16 |
uint64(data[pos+4])<<24 |
uint64(data[pos+5])<<32 |
uint64(data[pos+6])<<40 |
uint64(data[pos+7])<<48 |
uint64(data[pos+8])<<56, pos + 9, true
}
return uint64(data[pos]), pos + 1, true
}
写在最后:当程序员被逼急了...
这个项目让我深刻体会到:程序员的创造力往往来自于解决同事的"奇怪需求"。谁能想到,一个易语言的线程安全问题,最终会催生一个MySQL协议服务器?
所以,下次如果你的同事提出一个看似不合理的需求,不妨换个角度想想——也许这正是你提升技能、探索未知领域的好机会!
最后,如果你也有类似的"被逼无奈"的开发经历,欢迎在评论区分享,让我们一起吐槽,不对,让我们一起学习!
P.S. 同事的易语言项目现在运行得很稳定,他对我说:"还有很多地方可以这样优化,帮我再写几个接口吧!" 我沉默不语......
其他部分文章列表(注意:来源于本人csdn原创)
- 用 Go 手搓一个内网 DNS 服务器:从此告别 IP 地址,用域名畅游家庭网络!
- 我用Go写了个华容道游戏,曹操终于不用再求关羽了!
- 用 Go 接口把 Excel 变成数据库:一个疯狂但可行的想法
- 穿墙术大揭秘:用 Go 手搓一个"内网穿透"神器!
- 布隆过滤器(go):一个可能犯错但从不撒谎的内存大师
- 自由通讯的魔法:Go从零实现UDP/P2P 聊天工具
- Go语言实现的简易远程传屏工具:让你的屏幕「飞」起来
- 当你的程序学会了"诈尸":Go 实现 Windows 进程守护术
- 验证码识别API:告别收费接口,迎接免费午餐
- 用 Go 给 Windows 装个"顺风耳":两分钟写个录音小工具
- 使用 Go + govcl 实现 Windows 资源管理器快捷方式管理器
- 用 Go 手搓一个 NTP 服务:从"时间混乱"到"精准同步"的奇幻之旅
- 用 Go 手搓一个 Java 构建工具:当 IDE 不在身边时的自救指南
- 深入理解 Windows 全局键盘钩子(Hook):拦截 Win 键的 Go 实现
- 用 Go 语言实现《周易》大衍筮法起卦程序
- Go 语言400行代码实现 INI 配置文件解析器:支持注释、转义与类型推断
- 高性能 Go 语言带 TTL 的内存缓存实现:精确过期、自动刷新、并发安全
- Golang + OpenSSL 实现 TLS 安全通信:从私有 CA 到动态证书加载