100 行代码实现具有防止缓存击穿的缓存层数据请求的库

105 阅读3分钟

需求

日常开发业务中,优先从缓存数据库中获取数据,如果缓存中没有找到数据就从 DB 中找,找出后再保存到缓存数据库中是非常常见的需求,这里我们用不到 100 行代码实现这种功能

首先使用一个 Cache 结构体

type Cache struct {
  cacheKeyPrefix string // 用来作为缓存 key 的前缀
  cacheClient    client.CacheClient // 缓存客户端
}

cache 结构体上需要实现一个 fetch 的方法,fetch 方法接受 4 个参数,分别是:缓存 key,返回的结果,缓存超时时间,以及未找到缓存时的回调函数

// key 是缓存的 key, result 是最终返回的结果,可以是任意的类型, ex 设置缓存时的超时时间,fn 是从 DB 获取数据的函数
func (c *Cache) Fetch(ctx context.Context, key string, result interface{}, ex time.Duration, fn func() (rawResult interface{}, err error)) (ok bool, err error) {
  returnValue := reflect.ValueOf(result).Elem()

  // 如果缓存中数据存在,直接返回
  exist, err := c.Get(ctx, key, &result)
  if err != nil {
    return false, err
  }
  if exist {
    return true, nil
  }

  // 从数据库中获取数据
  data, err := fn()
  if err != nil {
    return data, err
  }
  // 获取到的数据塞入缓存中
  err = c.Set(ctx, key, data, ex)
  if err != nil {
    return
  }
  // 设置返回值
  returnValue.Set(reflect.ValueOf(data))
  return true, nil
}


func (c *Cache) Get(ctx context.Context, key string, returnValue interface{}) (exist bool, err error) {
	result := c.cacheClient.Get(ctx, c.cacheKey(key))
	if result == "" {
		return false, nil
	}

  // 使用 msgpack 减少数据的大小
	err = msgpack.Unmarshal([]byte(result), returnValue)
	if err != nil {
		return false, err
	}

	return true, nil
}

func (c *Cache) Set(ctx context.Context, key string, value interface{}, ex time.Duration) (err error) {
	bytes, err := msgpack.Marshal(value)
	if err != nil {
		return
	}

	return c.cacheClient.Set(ctx, c.cacheKey(key), bytes, ex)
}

func (c *Cache) cacheKey(key string) string {
	return fmt.Sprintf("%s:%s", c.cacheKeyPrefix, key)
}

缓存数据库可以是 Redis 或者其他类型的缓存数据库,只要实现了下面的 interface 就可以

type CacheClient interface {
  Get(ctx context.Context, key string) string // 获取缓存数据
  Set(ctx context.Context, key string, value interface{}, ex time.Duration) error // 设置缓存数据
  Del(ctx context.Context, key string) error // 删除缓存数据
}

接下来我们以 Redis 为例,首先实现 CacheClient 的方法


import (
  "context"
  "time"

  redis "github.com/go-redis/redis/v8"
)

type RedisClient struct {
  *redis.Client
}

func NewRedis(client *redis.Client) *RedisClient {
  return &RedisClient{
    Client: client,
  }
}

func (r *RedisClient) Get(ctx context.Context, key string) string {
  result, err := r.Client.Get(ctx, key).Result()
  if err != nil {
    return ""
  }

  return result
}

func (r *RedisClient) Set(ctx context.Context, key string, value interface{}, ex time.Duration) error {
  return r.Client.SetEX(ctx, key, value, ex).Err()
}

func (r *RedisClient) Del(ctx context.Context, key string) error {
  return r.Client.Del(ctx, key).Err()
}

然后可以按以下方式使用

  func NewCache(cacheClient client.CacheClient, cacheKeyPrefix string) *Cache {
    if cacheKeyPrefix == "" {
      cacheKeyPrefix = "go-cache-fetch"
    }

    return &Cache{
      cacheKeyPrefix: cacheKeyPrefix,
      cacheClient:    cacheClient,
    }
  }

  redisClient := NewRedis(redisClient)
  cache := newCache(redisClient, "prefix")

  // 使用 fetch
  type Data struct {}
  var result []Data
  ok, err := cache.Fetch(ctx, "test", &result, time.Minute, func() (rawResult interface{}, err error) {
    return getDataFromDB()
  })

  // 最终的 result 就是我们需要的数据

到这里,我们还会发现一个问题,如果 "test" 的缓存失效,在秒杀或者其他高并发场景下,同一时间有大量的请求会到 DB 中,短时间内可能会造成 DB 的 CPU 飙升,影响到其他服务,这就是常说的缓存穿透,我们再加 2 行代码就可以防止这种情况的发生。

type Cache struct {
	g              singleflight.Group
	cacheKeyPrefix string
	cacheClient    client.CacheClient
}

func (c *Cache) Fetch(ctx context.Context, key string, result interface{}, ex time.Duration, fn func() (rawResult interface{}, err error)) (ok bool, err error) {
	returnValue := reflect.ValueOf(result).Elem()

	exist, err := c.Get(ctx, key, &result)
	if err != nil {
		return false, err
	}
	if exist {
		return true, nil
	}

	// 防止缓存穿透
	res, err, _ := c.g.Do(key, func() (interface{}, error) {
		data, err := fn()
		if err != nil {
			return data, err
		}
		err = c.Set(ctx, key, data, ex)
		return data, err
	})
	if err != nil {
		return
	}

	returnValue.Set(reflect.ValueOf(res))

	return true, nil
}

singleflight 常用来将相同的并发请求合并成一个请求,进而减少对下层服务的压力,通常用于解决缓存击穿的问题

以上就实现了一个相对完整的功能满足我们的需求拉。

更多的实现细节和用例请到 github 查看和下载使用