用 go 实现内存数据库

269 阅读13分钟

实现 dict

我们先来定义 redis 内部的数据结构 dictdict 是一个 key-value 的数据结构,key 是一个字符串,value 是一个 interface{}

dict 会有很多方法,比如 GetPut 等,我们先来定义 dict 的接口

type Consumer func(key string, val interface{}) bool
type Dict interface {
  Get(key string) (val interface{}, exists bool)
  Put(key string, val interface{}) (result int)         // 存进去回复 1,没存进去回复 0
  PutIfAbsent(key string, val interface{}) (result int) // 如果 key 不存在,才能设置,如果 key 存在回复 0,不存在回复 1
  PutIfExists(key string, val interface{}) (result int) // 如果 key 存在,才能设置,如果 key 存在回复 1,不存在回复 0
  Remove(key string) (result int)
  ForEach(consumer Consumer)             // 遍历所有键值对
  Keys() []string                        // 获取所有的 key
  RandomKeys(limit int) []string         // 返回 n 个 key
  RandomDistinctKeys(limit int) []string // 返回 n 个不重复的 key
  Len() int
  Clear() // 清空数据结构
}

对于 Dict 的实现,我们需要用并发安全的 map 去实现,也就是 sync.Map

type SyncDict struct {
  m sync.Map
}
func MakeSyncDict() *SyncDict {
  return &SyncDict{}
}

然后实现 Dict 接口

Get

sync.Map 中获取值,sync.Map 获取值的方法是 Load

func (dict *SyncDict) Get(key string) (val interface{}, exists bool) {
  val, ok := dict.m.Load(key)
  return val, ok
}

Len

通过 Range 遍历 Sync.Map 中的所有键值对,累加遍历的次数,得到 dict 的长度

func (dict *SyncDict) Len() int {
  length := 0
  dict.m.Range(func(key, value any) bool {
    length++
    return true
  })
  return length
}

Put

Put 方法是往 sync.Map 中存值,sync.Map 存值的方法是 Store,如果设置成功,返回 1,否则返回 0

func (dict *SyncDict) Put(key string, val interface{}) (result int) {
  _, existed := dict.m.Load(key)
  dict.m.Store(key, val)
  if existed {
    return 0
  }
  return 1
}

PutIfAbsent

PutIfAbsent 方法是如果 key 不存在,才能设置,如果 key 存在回复 0,不存在回复 1

func (dict *SyncDict) PutIfAbsent(key string, val interface{}) (result int) {
  _, existed := dict.m.Load(key)
  if existed {
    return 0
  }
  dict.m.Store(key, val)
  return 1
}

PutIfExist

PutIfExist 方法是如果 key 存在,才能设置,如果 key 存在回复 1,不存在回复 0

func (dict *SyncDict) PutIfExists(key string, val interface{}) (result int) {
  _, existed := dict.m.Load(key)
  if existed {
    dict.m.Store(key, val)
    return 1
  }
  return 0
}

Remove

Remove 方法是删除 key,如果 key 存在,返回 1,否则返回 0sync.Map 删除值的方法是 Delete

func (dict *SyncDict) Remove(key string) (result int) {
  _, existed := dict.m.Load(key)
  dict.m.Delete(key)
  if existed {
    return 1
  }
  return 0
}

ForEach

ForEach 方法是遍历所有键值对,sync.Map 的遍历方法是 Range,遍历的时候调用 Consumer 函数

func (dict *SyncDict) ForEach(consumer Consumer) {
  dict.m.Range(func(key, value any) bool {
    consumer(key.(string), value)
    return true
  })
}

Keys

Keys 方法是获取所有的 key,刚开始初始化时,通过 dict.Len() 设置切片的长度,dict.Len 方法就是上面实现的 Len 方法

func (dict *SyncDict) Keys() []string {
  // 用 dict.Len() 设置切片的长度
  result := make([]string, dict.Len())
  i := 0
  dict.m.Range(func(key, value any) bool {
    // 遍历时,将 key 转为 string 类型,存入切片
    result[i] = key.(string)
    i++
    return true
  })
  return result
}

RandomKeys

RandomKeys 方法接收一个参数 limit,返回 limitkey

func (dict *SyncDict) RandomKeys(limit int) []string {
  // 用 limit 设置切片的长度
  result := make([]string, dict.Len())
  // 遍历 limit
  for i := 0; i < limit; i++ {
    // 通过随机数获取 key
    dict.m.Range(func(key, value any) bool {
      result[i] = key.(string)
      // return false 的作用是只遍历一次,这样就保证一次 for 循环时,随机获取其中一个 key
      return false
    })
  }
  return result
}

RandomDistinctKeys

RandomDistinctKeys 方法接收一个参数 limit,返回 limit 个不重复的 key

goforEach 是无序的,所以需要在外面用 i 记录下遍历了多次次,如果 i 等于 limit,就返回 false,否则一直遍历下去

func (dict *SyncDict) RandomDistinctKeys(limit int) []string {
  result := make([]string, dict.Len())
  i := 0
  dict.m.Range(func(key, value any) bool {
    result[i] = key.(string)
    i++
    // 如果 i 等于 limit 说明遍历了 limit 次,就返回 false,结束遍历
    if i == limit {
      return false
    }
    return true
  })
  return result
}

Clear

Clear 方法是清空数据结构,我们可以直接将 dict 重新赋值为一个新的 dict

这里要注意的是要用 *dict,用 dict 不会修改原来的 dict

func (dict *SyncDict) Clear() {
  // dict = MakeSyncDict(),这里相当于新建了一个 dict 指针,原来的 dict 不会有任何变化
  *dict = *MakeSyncDict()
}

实现 command

dictredis 最底层的数据结构,它的上一层是 dbdb 就是一个分数据库(redis 一共有 16 个数据库)

type DB struct {
  index  int       // 数据的编号
  data   dict.Dict // 数据类型
}
func makeDB() *DB {
  return &DB{
    data:   dict.MakeSyncDict(),
  }
}

声明两数据类型,第一个类型是一个方法 ExecFunc,它的作用是 redis 中所有的指令的类型

CmdLine 是一个二维的字节切片

type ExecFunc func(db *DB, args [][]byte) resp.Reply
type CmdLine = [][]byte

每一个 command 会有一个执行方法,也就是说每一个指令,比如 PingGET 都有一个 command 的结构体,这个结构体里面都有一个 exector 的方法

外面就实现 exector 这个方法,然后施加到 db

arity 是参数的数量

type command struct {
	exector ExecFunc // 执行方法
	arity   int      // 参数数量
}

然后在定义一个 cmdTable,它的类型就是一个 commandmap,用来存放所有的 command

var cmdTable = make(map[string]*command)

然后在实现一个注册指令的方法

func RegisterCommand(name string, exector ExecFunc, arity int) {
  name = strings.ToLower(name)
  cmdTable[name] = &command{exector: exector, arity: arity}
}

实现 DB.Exec

Exec 方法是执行指令的方法,它接收一个 CmdLine 参数,CmdLine 就是上面定义的 [][]byte

func (db *DB) Exec(c resp.Connection, cmdLine CmdLine) resp.Reply {
  // 拿到指令的名字,比如 GET、SET
  cmdName := strings.ToLower(string(cmdLine[0]))
  // 从 cmdTable 中拿到指令的结构体
  cmd, ok := cmdTable[cmdName]
  // 如果指令不存在,返回错误
  if !ok {
    return reply.MakeErrReply("ERR unknown command " + cmdName)
  }
  // 验证参数的数量,如果参数数量不对,返回错误
  if !validateArity(cmd.arity, cmdLine) {
    return reply.MakeArgNumErrReply(cmdName)
  }
  // 拿到指令的执行方法
  fun := cmd.exector
  // 执行指令,传入 db 和指令的参数
  // SET key value,cmdLine[1:] 就是 key 和 value
  return fun(db, cmdLine[1:])
}

validateArity

validateArity 方法是验证指令传入的参数数量是否正确

redis 的指令有两种情况

  • 固定参数,比如 SET Key Value,所以参数数量必须等于 arity
  • 不固定参数,比如 EXISTS k1 k2 ...,这种参数不固定的,arity 会赋值 -2,所以参数数量必须大于等于 -arity
func validateArity(arity int, cmdArgs [][]byte) bool {
  argNum := len(cmdArgs)
  // 固定参数,比如 SET Key Value,所以参数数量必须等于 arity
  if arity > 0 {
    return argNum == arity
  }
  // 不固定参数,比如 EXISTS k1 k2 ...,这种参数不固定的,arity 会赋值 -2,所以参数数量必须大于等于 -arity
  return argNum >= -arity
}

GetEntity

这是一个 db 的方法,是对 dict 的封装

GetEntity 方法是从 db 中通过 key 获取对应的 value

func (db *DB) GetEntity(key string) (*databaseface.DataEntity, bool) {
  // 调用 dict 的 Get 方法,通过 key 获取 value
  raw, ok := db.data.Get(key)
  if !ok {
    return nil, false
  }
  // 底层 dict 存的是 interface{},所以需要转为 *DataEntity
  entity, _ := raw.(*databaseface.DataEntity)
  return entity, true
}

PutEntity

PutEntity 方法是往 db 中存入 keyvaluedictPut 也是返回 10,所以这里也是返回 10

func (db *DB) PutEntity(key string, entity *databaseface.DataEntity) int {
  return db.data.Put(key, entity)
}

PutIfExists

PutIfExists 方法是如果 key 存在,才能设置,如果 key 存在回复 1,不存在回复 0

func (db *DB) PutIfExists(key string, entity *databaseface.DataEntity) int {
  return db.data.PutIfExists(key, entity)
}

PutIfAbsent

PutIfAbsent 方法是如果 key 不存在,才能设置,如果 key 存在回复 0,不存在回复 1

func (db *DB) PutIfAbsent(key string, entity *databaseface.DataEntity) int {
  return db.data.PutIfAbsent(key, entity)
}

Remove

Remove 方法是删除 key,如果 key 存在,返回 1,否则返回 0

func (db *DB) Remove(key string) int {
  return db.data.Remove(key)
}

Removes

Removes 是删除一组的key

通过遍历 keys,调用上面实现的 Remove 方法,删除每个 key

func (db *DB) Removes(keys ...string) (deleted int) {
  deleted = 0
  // 遍历 keys,删除每一个 key
  for _, key := range keys {
    // 先判断 key 是否存在,存在就删除,并且 deleted++
    _, exists := db.data.Get(key)
    if exists {
      db.Remove(key)
      deleted++
    }
  }
  // 返回删除的数量
  return deleted
}

Flush

Flush 方法是清空数据结构,调用 dict.Clear 方法可以实现清空数据库的操作

func (db *DB) Flush() {
  db.data.Clear()
}

实现 Ping 指令

Ping 指令是 redis 中最简单的指令,只是返回一个 PONG,所以我们先来实现 Ping 指令

Ping 指令是比较简单的,直接调用 reply.MakePongReply 方法,返回一个 PONG 的回复

func Ping(db *DB, args [][]byte) resp.Reply {
  return reply.MakePongReply()
}

在这个文件中定义一个 init 方法,文件加载时会自动执行这个方法

这个方法会注册 Ping 指令

func init() {
  RegisterCommand("ping", Ping, 1) // 注册 Ping 指令,参数数量 arity 是 1
}

实现 Keys 相关指令

Keys 相关指令有:DELEXISTSKEYSFLUSHDBRENAMERENAMENX 这几个

DEL

DEL 指令可以删除一堆 key,所以参数数量是不固定的,arity-2

// DEL K1 K2 K3
func DEL(db *DB, args [][]byte) resp.Reply {
  keys := make([]string, len(args))
  for i, v := range args {
    keys[i] = string(v)
  }
  deleted := db.Removes(keys...)
  return reply.MakeIntReply(int64(deleted))
}

然后将 DEL 指令注册到 cmdTable 中,arity-2,代表参数数量是不固定的

RegisterCommand("DEL", DEL, -2)

EXISTS

EXISTS 指令是判断 key 是否存在,参数数量是不固定的,arity-2

// EXISTS K1 K2 K3 ...
func EXISTS(db *DB, args [][]byte) resp.Reply {
  result := int64(0)
  for _, arg := range args {
    key := string(arg)
    _, exists := db.GetEntity(key)
    if exists {
      result++
    }
  }
  return reply.MakeIntReply(result)
}

然后将 EXISTS 指令注册到 cmdTable 中,arity-2,代表参数数量是不固定的

RegisterCommand("EXISTS", EXISTS, -2)

FLUSHDB

FLUSHDB 指令是清空数据库,后面不需要参数,如果有参数,直接忽略

func FLUSHDB(db *DB, args [][]byte) resp.Reply {
	db.Flush()
	return reply.MakeOkReply()
}

然后将 FLUSHDB 指令注册到 cmdTable 中,忽略后面的参数,所以 arity-1

RegisterCommand("FLUSHDB", FLUSHDB, -1)

TYPE

TYPE 指令在 redis 中是获取一个 key 的数据类型,这里我们只有 string 类型,所以直接返回 string

// TYPE K1
func TYPE(db *DB, args [][]byte) resp.Reply {
  key := string(args[0])
  entity, exists := db.GetEntity(key)
  // 如果 key 不存在,返回 none
  if !exists {
    return reply.MakeStatusReply("none") // TCP :none\r\n
  }
  switch entity.Data.(type) {
  // 如果是 []byte 类型,返回 string
  case []byte:
    return reply.MakeStatusReply("string")
  }
  return reply.MakeUnknownErrReply()
}

然后将 TYPE 指令注册到 cmdTable 中,arity2

RegisterCommand("TYPE", TYPE, 2)  // TYPE K1

RENAME

RENAME 指令是重命名 keyRENAME K1 K2,将 K1 重命名为 K2,如何设置成功,返回 OK,否则返回错误

// RENAME K1 K2
func RENAME(db *DB, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 K1
  src := string(args[0])
  // 从用户输入的命令中拿到 K2
  dest := string(args[1])
  // 通过 K1 获取 entity
  entity, exists := db.GetEntity(src)
  // 如果 K1 不存在,返回错误
  if !exists {
    return reply.MakeErrReply("no such key")
  }
  // 将 K2 的值设置为 K1 的 entity
  db.PutEntity(dest, entity)
  // 删除 K1
  db.Remove(src)
  return reply.MakeOkReply()
}

然后将 RENAME 指令注册到 cmdTable 中,arity3

RegisterCommand("RENAME", RENAME, 3)  // RENAME K1 K2

RENAMENX

RENAMENX 指令是重命名 keyRENAMENX K1 K2,将 K1 重命名为 K2

RENAMENXRENAME 的区别是,如果 K2 存在,RENAMENX 不会重命名(返回 0,设置成功了返回 1),RENAME 会覆盖

// RENAMENX K1 K2
func RENAMENX(db *DB, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 K1
  src := string(args[0])
  // 从用户输入的命令中拿到 K2
  dest := string(args[1])
  _, ok := db.GetEntity(dest)
  // 如果 K2 存在,啥也不干
  if ok {
    return reply.MakeIntReply(0)
  }
  entity, exists := db.GetEntity(src)
  // 如果 K1 不存在,返回错误
  if !exists {
    return reply.MakeErrReply("no such key")
  }
  // 将 K2 的值设置为 K1 的 entity
  db.PutEntity(dest, entity)
  // 删除 K1
  db.Remove(src)
  return reply.MakeIntReply(1)
}

然后将 RENAMENX 指令注册到 cmdTable 中,arity3

RegisterCommand("RENAMENX", RENAMENX, 3)  // RENAMENX K1 K2

KEYS

KEYS 指令是获取所有的 keyKEYS *,返回所有的 key

KEYS 指令是通过 wildcard 包来实现的,wildcard 包是一个通配符的包,可以通过 * 来匹配所有的 key

// KEYS *
func KEYS(db *DB, args [][]byte) resp.Reply {
  // 通过 wildcard 包来实现通配符
  pattern := wildcard.CompilePattern(string(args[0]))
  // 存放所有通配符的 key
  result := make([][]byte, 0)
  db.data.ForEach(func(key string, val interface{}) bool {
    // 如果 key 匹配通配符,就存入 result
    if pattern.IsMatch(key) {
      // 将 key 转为 []byte 类型,存入 result
      result = append(result, []byte(key))
    }
    // 返回 true,继续遍历
    return true
  })
  // 返回所有匹配的 key
  return reply.MakeMultiBulkReply(result)
}

然后将 KEYS 指令注册到 cmdTable 中,arity2

RegisterCommand("KEYS", KEYS, 2)  // KEYS *

实现 String 相关指令

String 相关指令有:GETSETSETNXGETSETSTRLEN 这几个

GET

GET 指令是获取 key 的值,GET K1,返回 K1 的值

// GET K1
func GET(db *DB, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 K1
  key := string(args[0])
  // 通过 K1 获取 entity
  entity, exists := db.GetEntity(key)
  // 如果 K1 不存在,返回 nil
  if !exists {
    return reply.MakeNullBulkReply()
  }
  bytes := entity.Data.([]byte)
  // 将 data 包装成 MakeBulkReply
  return reply.MakeBulkReply(bytes)
}

然后将 GET 指令注册到 cmdTable 中,arity2

RegisterCommand("GET", GET, 2)

SET

SET 指令是设置 key 的值,SET K1 V1,将 K1 的值设置为 V1

// SET K1 v
func SET(db *DB, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 K1
  key := string(args[0])
  // 从用户输入的命令中拿到 v
  value := args[1]
  // 将 v 存入 entity
  entity := &databaseface.DataEntity{Data: value}
  // 将 K1 和 entity 存入 db
  db.PutEntity(key, entity)
  // 返回 OK
  return reply.MakeOkReply()
}

然后将 SET 指令注册到 cmdTable 中,arity3

RegisterCommand("SET", SET, 3)

SETNX

SETNX 指令和 SET 指令的区别是,SETNX 指令是如果 key 不存在,才能设置,如果 key 存在回复 0,不存在回复 1

// SETNX K1 v
func SETNX(db *DB, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 K1
  key := string(args[0])
  // 从用户输入的命令中拿到 v
  value := args[1]
  // 将 v 存入 entity
  entity := &databaseface.DataEntity{Data: value}
  // 调用 PutIfAbsent,将 K1 和 entity 存入 db,如果 key 不存在,返回 1,存在返回 0
  result := db.PutIfAbsent(key, entity)
  return reply.MakeIntReply(int64(result))
}

然后将 SETNX 指令注册到 cmdTable 中,arity3

RegisterCommand("SETNX", SETNX, 3)

GETSET

GETSET 指令是设置 key 的值,并且返回 key 的旧值,GETSET K1 V1,将 K1 的值设置为 V1,返回 K1 的旧值

// GETSET K1 v1
func GETSET(db *DB, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 K1
  key := string(args[0])
  // 从用户输入的命令中拿到 v1
  value := args[1]
  // 通过 K1 获取 entity
  entity, exists := db.GetEntity(key)
  // 将 v1 转为 []byte 存入 db 中
  db.PutEntity(key, &databaseface.DataEntity{Data: value})
  // 如果 K1 不存在,返回 nil
  if !exists {
    return reply.MakeNullBulkReply()
  }
  // 返回 K1 的旧值
  return reply.MakeBulkReply(entity.Data.([]byte))
}

然后将 GETSET 指令注册到 cmdTable 中,arity3

RegisterCommand("GETSET", GETSET, 3)

STRLEN

STRLEN 指令是获取 key 的长度,STRLEN K1,返回 K1 的长度

// STRLEN K
func STRLEN(db *DB, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 K
  key := string(args[0])
  // 通过 K 获取 entity
  entity, exists := db.GetEntity(key)
  // 如果 K 不存在,返回 null
  if !exists {
    return reply.MakeNullBulkReply()
  }
  bytes := entity.Data.([]byte)
  // 返回 K 的长度
  return reply.MakeIntReply(int64(len(bytes)))
}

然后将 STRLEN 指令注册到 cmdTable 中,arity2

RegisterCommand("STRLEN", STRLEN, 2)

实现核心 Database

redis 基本指令都实现了之后,还差一个最核心的内核

我们 redis 当成一个数据库,它内部有 16 个分数据库,每个数据库都是一个 DB,所以我们需要实现一个 Database,它是一个 DB 的集合

type StandaloneDatabase struct {
  dbSet []*DB
}

这个结构体需要实现 Database 接口:

type Database interface {
	Exec(client resp.Connection, args [][]byte) resp.Reply // 执行指令:执行的客户端 client,执行的指令 args,返回的是 Reply
	Close()                                                // 关闭数据库
	AfterClientClose(c resp.Connection)                    // 数据库关闭之后的善后工作
}

这个接口最核心的是 Exec 方法,它接收一个 clientargs,返回一个 Reply

这个方法主要的作用就是调用分数据库的 db.Exec 方法,然后返回结果

func (database *StandaloneDatabase) Exec(client resp.Connection, args [][]byte) resp.Reply {
  // 用来 recover panic
  defer func() {
    if err := recover(); err != nil {
      logger.Error(err)
    }
  }()
  // 拿到第一个参数,也就是指令的名字
  cmdName := strings.ToLower(string(args[0]))
  // 如果指令是 select,切换数据库
  if cmdName == "select" {
    if len(args) != 2 {
      return reply.MakeArgNumErrReply("select")
    }
    execSelect(client, database, args[1:])
  }
  // 从 client 中拿到 dbIndex
  dbIndex := client.GetDBIndex()
  // 通过 dbIndex 拿到 db
  db := database.dbSet[dbIndex]
  // 执行指令
  return db.Exec(client, args)
}

其中 execSelect 方法是切换数据库的方法

// select 2
func execSelect(c resp.Connection, database *StandaloneDatabase, args [][]byte) resp.Reply {
  // 从用户输入的命令中拿到 dbIndex
  dbIndex, err := strconv.Atoi(string(args[0]))
  // 如果 dbIndex 不是数字,返回错误
  if err != nil {
    return reply.MakeErrReply("ERR invalid DB index")
  }
  // 如果 dbIndex 超出了数据库的范围,返回错误
  if dbIndex >= len(database.dbSet) {
    return reply.MakeErrReply("ERR db index is out of range")
  }
  // 设置 dbIndex
  c.SelectDB(dbIndex)
  // 返回 OK
  return reply.MakeOkReply()
}

然后在实现 StandAloneDatabaseNew 方法

func NewStandaloneDatabase() *StandaloneDatabase {
  database := &StandaloneDatabase{}
  if config.Properties.Databases == 0 {
    config.Properties.Databases = 16
  }
  // 初始化 16 个数据库
  database.dbSet = make([]*DB, config.Properties.Databases)
  // 初始化每一个数据库
  for i := range database.dbSet {
    db := makeDB()
    db.index = i
    database.dbSet[i] = db
  }
  return database
}

测试

运行 main.go,然后用 nc 连接

echo -ne '*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n' | nc 127.0.0.1 3000

源码

  1. command
  2. standalone_database
  3. db