Singleflight 巧妙解决缓存击穿

4 阅读4分钟

Singleflight 解析

在并发编程中,singleflight 是一个非常实用的工具库(golang.org/x/sync/singleflight),主要用于请求合并

当多个 Goroutine 同时请求同一个 Key(例如查询同一个热门用户的个人信息)时,singleflight 能确保只发起一次实际的函数调用(如查数据库),然后将结果共享给所有等待的 Goroutine。这能有效防止缓存击穿,避免数据库在高并发下瞬间被打崩。

1. 快速上手

下面我们看一个简单的 Demo:模拟 10 个并发请求同时查询同一个 Key,观察实际的数据库调用次数。

package main

import (
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/singleflight"
)

var (
	dbCallCount int32            // 统计数据库调用次数
	g singleflight.Group // Singleflight 核心对象
)

// getDataFromDB 模拟一个耗时的数据库查询操作
func getDataFromDB(key string) (string, error) {
	log.Printf("正在查询 %s 的数据...", key)
	time.Sleep(1 * time.Second) // 模拟耗时 1 秒
	atomic.AddInt32(&dbCallCount, 1)
	return fmt.Sprintf("Data for %s", key), nil
}

// simulateConcurrentRequests 模拟并发请求
func simulateConcurrentRequests(n int, key string) {
	var wg sync.WaitGroup
	wg.Add(n)

	log.Printf("开始模拟 %d 个并发请求查询 key: %s", n, key)

	for i := 0; i < n; i++ {
		go func(id int) {
			defer wg.Done()

			// 使用 singleflight.Do 来合并请求
			// 只有第一个到达的请求会真正执行 getDataFromDB
			// 后续的请求会等待第一个请求的结果,直接返回
			v, err, shared := g.Do(key, func() (interface{}, error) {
				return getDataFromDB(key)
			})

			if err != nil {
				log.Printf("请求 %d 失败: %v", id, err)
				return
			}

			// shared=true 表示该结果是被多个请求共享的
			log.Printf("请求 %d 完成: 结果=%v, 是否共享=%v", id, v, shared)
		}(i)
	}

	wg.Wait()
	log.Printf("所有请求完成。")
}

func main() {
	key := "user:1001"
	simulateConcurrentRequests(10, key)
	
	count := atomic.LoadInt32(&dbCallCount)
	fmt.Printf("最终数据库调用次数: %d\n", count)
}

2. 实现解析

2.1 核心数据结构

singleflight 的核心是 Group 结构体,管理着所有正在进行的请求。

type Group struct {
	mu sync.Mutex       // 互斥锁,保护 m 的并发读写
	m  map[string]*call // 任务表:存储当前正在进行中的请求 (key -> call)
}

call 结构体是一个正在执行或已完成的任务

type call struct {
	wg sync.WaitGroup // 核心机制:用于阻塞等待结果
	
	// 共享结果字段:执行结束后所有等待的请求都能收到结果
	val interface{}   // 正常的返回值
	err error         // 错误返回值
	
	dups  int             // 统计有多少个请求在等待这个结果(重复请求数)
	chans []chan<- Result // 用于 DoChan(异步模式)的结果通知通道
}

2.2 核心流程 (Do 方法)

Do 是最常用的同步阻塞方法:

  1. 先加锁,看 map 里有没有这个 Key。
  2. 如果 Key 已存在,说明有请求正在查。当前请求不需要自己查,直接调用 c.wg.Wait() 原地阻塞等待。等前面请求查完直接共享结果。
  3. 如果 Key 不存在, 创建一个新的 call 对象,放入 map,并调用 c.wg.Add(1) 阻塞后面的请求。 释放锁(让其他 Key 的请求能进来)。 执行业务函数 (fn)。查完后,把结果填入 call,调用 c.wg.Done() 广播通知所有等待的人。 最后再次加锁,把这个 Key 从 map 中删掉(防止内存泄漏,也为了让下一次请求能重新查最新数据)。
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
	g.mu.Lock()
	// (惰性初始化 map)

	// 1. 如果任务正在进行中,直接等待
	if c, ok := g.m[key]; ok {
		c.dups++
		g.mu.Unlock()
		c.wg.Wait() 
		return c.val, c.err, true
	}

	// 2. 如果任务未开始
	c := new(call)
	c.wg.Add(1)
	g.m[key] = c
	g.mu.Unlock()

	// 3. 执行业务逻辑
	g.doCall(c, key, fn)
	return c.val, c.err, c.dups > 0
}

2.3 异常处理 (doCall 方法)

  1. 如果业务函数崩了(Panic),singleflight 会捕获这个 Panic,并将其抛给所有等待的请求,避免一个请求挂掉导致其他等待者死锁。
  2. 如果业务函数调用了 runtime.Goexit() 退出,singleflight 也可识别并正确清理资源。

3. 解决缓存击穿

在项目中 UserServiceGetProfile 方法中使用了 singleflight 来防止缓存击穿。

场景描述

当某个热门用户(Key)的缓存过期时,如果有 1000 个请求同时涌入:

  • 不使用 Singleflight:1000 个请求全部打到 MySQL
  • 使用 Singleflight:1000 个请求合并为 1 个 MySQL 查询。第 1 个请求去查库,剩下 999 个等待结果。

代码实现

func (s *UserService) GetProfile(ctx context.Context, lg *zap.Logger, uid int) (*models.User, error) {
	// 1. 先查缓存
	user, err := s.userCache.GetProfile(ctx, uid)
	if err == nil && user != nil {
		return user, nil
	}
	
	// 防穿透:如果缓存中为用户不存在,直接返回
	if errors.Is(err, cache.ErrCacheNotFound) {
		return nil, apperrors.NewNotFoundError("用户不存在")
	}

	// 2. 缓存未命中,准备查库
	key := fmt.Sprintf("user:profile:%d", uid)
	
	// Do 方法确保并发请求合并
	val, err, _ := s.sf.Do(key, func() (interface{}, error) {
		// 查数据库
		dbUser, err := s.repo.GetByID(ctx, uid)
		if err != nil {
			// 防穿透:如果数据库也没查到,写入一个空值到缓存(TTL较短)
			if errors.Is(err, gorm.ErrRecordNotFound) {
				go func() {
					ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
					defer cancel()
					_ = s.userCache.SetProfile(ctx, uid, nil)
				}()
				return nil, apperrors.NewNotFoundError("用户不存在")
			}
			return nil, apperrors.NewInternalError("系统错误")
		}

		// 查到了数据,回写缓存 (异步执行,不阻塞当前请求返回)
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			defer cancel()
			// 设置随机过期时间,防止缓存雪崩
			if err := s.userCache.SetProfile(ctx, uid, dbUser); err != nil {
				lg.Warn("get_profile.cache_set_failed", zap.Error(err))
			}
		}()
		return dbUser, nil
	})
	if err != nil {
		return nil, err
	}
	return val.(*models.User), nil
}