基于go-redis实现分布式锁

1,101 阅读5分钟

分布式锁是什么

我们都知道,在业务开发中,多线程去处理共享数据,为保证数据安全,可以采用加锁来解决,例如go的map和slice操作,如果多线程操作map不加锁的话甚至会发生panic,操作slice也可能会导致结果和预想的不一致,go提供了sync.mutex包来实现了锁机制。但这都是单进程的锁,那么在分布式场景下呢,如下图

1675923161229.jpg

怎么实现分布式锁

目前流行的有几种方案,数据库、redis、zookeeper

使用redis怎么实现

1、加锁

我们可以使用 SETNX 命令,这个命令表示SET if Not eXists,即如果 key 不存在,才会设置它的值,否则什么也不做。两个客户端进程可以执行这个命令,达到互斥,就可以实现一个分布式锁。这时候的整体流程是这样:

image.png

2、增加“租期”

如上图,如果在处理业务这一步出现问题,例如服务器宕机,就会发生死锁问题,其他服务再也获取不到锁了。

我们可以给锁加一个过期时间,例如我们预估处理业务需要1s,我们可以设置过期时间为2s,这样的话不管处理业务 会出现什么问题,都可以保证锁的正常释放。

这样其实又引出来另外一个问题,setnx和expire并不是原子操作,还是有可能会出现加锁成功,设置过期时间失败的情况,所幸Redis 2.6.12 之后,Redis 扩展了 SET 命令的参数,用一条命令就可以解决:

SET lock 1 EX 10 NX

其实还有另外的方法可以解决,这里不再赘述。

3、解锁

业务处理完成后使用delete删除redis的key,看起来已经足够了,其实不然,我们从一张图来看一下,在某些情况下会发生什么:

image.png 我们可以看到因为第三步的a进程的业务处理时间超过了10s,所以进程b又加锁成功了,然后等第6步进程a去解锁的时候其实删除的是进程b设置的key。
怎么解决这个问题呢,我们再来看一张图:

image.png 我们稍微修改了一下这个流程,在第一步加锁的时候给value设置了一个随机值(注:真实的写法肯定不是直接set random(),redis也没有这个写法,这么写只是代表这里传了个随机值,具体可以一会看代码),然后在第6步的时候先检查值是否相等,再做删除。
其实这块还是有问题,第6步里的get和delete其实不是一个原子操作,极端情况下,刚好当进程a get完了准备删除的时候key失效了,这个时候进程b加锁成功了,那进程a还是释放了进程b的锁,所以这块需要使用lua脚本封装一下,让get和delete封装到一个原子操作中。

if redis.call('get', KEYS[1]) == ARGV[1]
   then redis.call('del', KEYS[1]) return 1
else
   return 0
end

到这里就结束吗?上文也只是避免了线程A误删掉key的情况,但是同一时间有 A,B 两个线程在访问代码块,仍然是不完美的,怎么解决呢?

看门狗

我们可以让获得锁的线程开启一个守护线程,用来给快要过期的锁“续期”。
假设进程a执行了7s后还没执行完,这时候守护线程会执行 expire 指令,为这把锁续期10s,这个守护线程每7s执行一次。这样的话当线程A执行完任务,直接关掉守护线程,如果线程A在执行业务时出现问题,例如宕机,由于线程 A 和守护线程在同一个进程,守护线程也会停下。这把锁到了超时的时候,没有给它续期,也会自动释放。

一般我们定时器轮询的时间为 2/3*expire

使用go-redis实现分布式锁的代码

package dispersed_lock

import (
	"context"
	"fmt"
	"github.com/go-redis/redis/v8"
	"github.com/lfxnxf/while"
	"math/rand"
	"sync"
	"time"
)

const (
	// 解锁lua
	unLockScript = "if redis.call('get', KEYS[1]) == ARGV[1] " +
		"then redis.call('del', KEYS[1]) return 1 " +
		"else " +
		"return 0 " +
		"end"

	// 看门狗lua
	watchLogScript = "if redis.call('get', KEYS[1]) == ARGV[1] " +
		"then return redis.call('expire', KEYS[1], ARGV[2]) " +
		"else " +
		"return 0 " +
		"end"

	lockMaxLoopNum = 1000 //加锁最大循环数量
)

var scriptMap sync.Map

type option func() (bool, error)

type DispersedLock struct {
	key            string        // 锁key
	value          string        // 锁的值,随机数
	expire         int           // 锁过期时间,单位秒
	lockClient     redis.Cmdable // 锁客户端,暂时只有redis
	unLockScript   string        // lua脚本
	watchLogScript string        // 看门狗lua
	unlockCh       chan struct{} // 解锁通知通道
}

// 生成分布式锁对象
// @client go-redis实例
// @key 锁的key
// @exoire 过期时间
func New(ctx context.Context, client redis.Cmdable, key string, expire int) *DispersedLock {
	d := &DispersedLock{
		key:    key,
		expire: expire,
		value:  fmt.Sprintf("%d", Random(100000000, 999999999)), // 随机值作为锁的值
	}

	//初始化连接
	d.lockClient = client

	//初始化lua script
	lockScript, _ := scriptMap.LoadOrStore("dispersed_lock", d.getScript(ctx, unLockScript))
	watchLogScript, _ := scriptMap.LoadOrStore("watch_log", d.getScript(ctx, watchLogScript))

	d.unLockScript = lockScript.(string)
	d.watchLogScript = watchLogScript.(string)

	d.unlockCh = make(chan struct{}, 0)

	return d
}

func (d *DispersedLock) getScript(ctx context.Context, script string) string {
	scriptString, _ := d.lockClient.ScriptLoad(ctx, script).Result()
	return scriptString
}

//加锁
func (d *DispersedLock) Lock(ctx context.Context) bool {
	ok, _ := d.lockClient.SetNX(ctx, d.key, d.value, time.Duration(d.expire)*time.Second).Result()
	if ok {
		go d.watchDog(ctx)
	}
	return ok
}

//循环加锁
//@sleepTime int 循环等待时间,单位毫秒
func (d *DispersedLock) LoopLock(ctx context.Context, sleepTime int) bool {
	t := time.NewTicker(time.Duration(sleepTime) * time.Millisecond)
	w := while.NewWhile(lockMaxLoopNum)
	w.For(func() {
		if d.Lock(ctx) {
			t.Stop()
			w.Break()
		} else {
			<-t.C
		}
	})
	if !w.IsNormal() {
		return false
	}
	return true
}

//解锁
func (d *DispersedLock) Unlock(ctx context.Context) bool {
	args := []interface{}{
		d.value, // 脚本中的argv
	}
	flag, _ := d.lockClient.EvalSha(ctx, d.unLockScript, []string{d.key}, args...).Result()
	// 关闭看门狗
	clese(d.unlockCh)
	return lockRes(flag.(int64))
}

//看门狗
func (d *DispersedLock) watchDog(ctx context.Context) {
	// 创建一个定时器NewTicker, 每过期时间的3分之2触发一次
	loopTime := time.Duration(d.expire*1e3*2/3) * time.Millisecond
	expTicker := time.NewTicker(loopTime)
	//确认锁与锁续期打包原子化
	for {
		select {
		case <-expTicker.C:
			args := []interface{}{
				d.value,
				d.expire,
			}
			res, err := d.lockClient.EvalSha(ctx, d.watchLogScript, []string{d.key}, args...).Result()
			if err != nil {
				fmt.Println("watchDog error", err)
				return
			}
			r, ok := res.(int64)
			if !ok {
				return
			}
			if r == 0 {
				return
			}
		case <-d.unlockCh: //任务完成后用户解锁通知看门狗退出
			return
		}
	}
}

func lockRes(flag int64) bool {
	if flag > 0 {
		return true
	} else {
		return false
	}
}

func Random(min, max int64) int64 {
	rand.Seed(time.Now().UnixNano())
	return rand.Int63n(max-min+1) + min
}

调用代码,获取锁失败直接结束:

dispersedLock := New(ctx, client, key, 1)
if !dispersedLock.Lock(ctx) {
   return
}
defer dispersedLock.Unlock(ctx)

获取锁失败轮询等待:

dispersedLock := New(ctx, client, key, 1)
if !dispersedLock.LoopLock(ctx, 10) {
   log.Errorw("loop lock error")
   return nil, zd_error.ServerError
}
defer dispersedLock.Unlock(ctx)