我们先来看一下这段代码
func Index(context *context.Context) *response.Response {
l := lock.NewLock("test", 1*time.Second)
defer l.Release()
if l.Get() {
time.Sleep(4 * time.Second)
return response.Resp().String("拿锁成功")
}
return response.Resp().String("拿锁失败")
}
在上面的控制器中,我们申请了一个过期时间为1s的锁,拿到锁之后开始进行业务操作,由于错误的判断,业务处理的时间超过的锁的过期时间。在这个时刻,有其他请求进来了,同样拿到了锁,那么在同一时刻,会有多个请求来处理这个业务,这个显然不是我们想要的,也违背了分布式锁的原则。
一般这种情况,有两种解决方案。一个是加大锁的过期时间,但是这个方案有太多的不确定性,所以用第二种,动态的给锁的过期时间续期。
先整理一下思路,实现的流程是,当一个请求获取到锁之后,开启一个协程,这个协程会有一个等待,当锁快要到期时,检查这个锁是否还存在,如果存在就给锁续期。
实现检查锁的方法
//检查锁是否被释放,未被释放就延长锁时间
func (lk *lock) checkLockIsRelease() {
for {
//检查的时间不能和锁的时间一致,要提前一点,否则锁会过期
checkCxt, _ := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(lk.expiration.Milliseconds()-lk.expiration.Milliseconds()/10))
select {
case <-checkCxt.Done():
//多次续期,直到锁被释放
isContinue := lk.done()
if !isContinue {
return
}
}
}
}
//判断锁是否已被释放
func (lk *lock) done() bool {
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
res, err := redis.Client().Exists(cxt, lk.key).Result()
cancel()
if err != nil {
return false
}
if res == 1 {
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ok, err := redis.Client().Expire(cxt, lk.key, lk.expiration).Result()
cancel()
if err != nil {
return false
}
if ok {
fmt.Println("续期")
return true
}
}
return false
}
接着在获取锁成功的地方调用
// Get 获取锁
func (lk *lock) Get() bool {
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
ok, err := redis.Client().SetNX(cxt, lk.key, lk.requestId, lk.expiration).Result()
if err != nil {
return false
}
if ok {
//锁续期检查
go lk.checkLockIsRelease()
}
return ok
}
到这里还有两个问题
- 锁被手动释放后,检查锁续期的协程还没有结束,造成了资源的浪费
- 锁的手动释放和锁的续期有竞争,需要加锁
修改分布式锁的结构体
type lock struct {
key string
expiration time.Duration
requestId string
checkCancel chan bool //释放续期协程的信道
mu sync.Mutex //锁
}
添加取消方法和加锁
//检查锁是否被释放,未被释放就延长锁时间
func (lk *lock) checkLockIsRelease() {
for {
checkCxt, _ := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(lk.expiration.Milliseconds()-lk.expiration.Milliseconds()/10))
lk.checkCancel = make(chan bool)
select {
case <-checkCxt.Done():
//多次续期,直到锁被释放
isContinue := lk.done()
if !isContinue {
return
}
//取消
case <-lk.checkCancel:
fmt.Println("释放")
return
}
}
}
//判断锁是否已被释放
func (lk *lock) done() bool {
//加锁
lk.mu.Lock()
defer lk.mu.Unlock()
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
res, err := redis.Client().Exists(cxt, lk.key).Result()
cancel()
if err != nil {
return false
}
if res == 1 {
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ok, err := redis.Client().Expire(cxt, lk.key, lk.expiration).Result()
cancel()
if err != nil {
return false
}
if ok {
fmt.Println("续期")
return true
}
}
return false
}
释放锁也要修改
// Release 释放锁
func (lk *lock) Release() error {
lk.mu.Lock()
defer lk.mu.Unlock()
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
const luaScript = `
if redis.call('get', KEYS[1])==ARGV[1] then
return redis.call('del', KEYS[1])
else
return 0
end
`
script := goredis.NewScript(luaScript)
res, err := script.Run(cxt, redis.Client(), []string{lk.key}, lk.requestId).Result()
if res.(int64) != 0 {
lk.checkCancel <- true
}
return err
}
运行开头中的控制器代码,控制台输出
模拟一秒发出一个请求
package main
import (
"fmt"
"github.com/PeterYangs/tools/http"
"sync"
"time"
)
func main() {
client := http.Client()
wait := sync.WaitGroup{}
for i := 0; i < 4; i++ {
wait.Add(1)
go func(index int) {
defer wait.Done()
time.Sleep(time.Duration(index) * time.Second)
str, _ := client.Request().GetToString("http://127.0.0.1:8080")
fmt.Println("请求", index, str)
}(i)
}
wait.Wait()
fmt.Println("finish")
}
控制台输出
完整代码
package lock
import (
"context"
"fmt"
goredis "github.com/go-redis/redis/v8"
uuid "github.com/satori/go.uuid"
"myGin/redis"
"sync"
"time"
)
type lock struct {
key string
expiration time.Duration
requestId string
checkCancel chan bool
mu sync.Mutex
}
func NewLock(key string, expiration time.Duration) *lock {
requestId := uuid.NewV4().String()
return &lock{key: key, expiration: expiration, requestId: requestId, mu: sync.Mutex{}}
}
// Get 获取锁
func (lk *lock) Get() bool {
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
ok, err := redis.Client().SetNX(cxt, lk.key, lk.requestId, lk.expiration).Result()
if err != nil {
return false
}
if ok {
//锁续期检查
go lk.checkLockIsRelease()
}
return ok
}
//检查锁是否被释放,未被释放就延长锁时间
func (lk *lock) checkLockIsRelease() {
for {
checkCxt, _ := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(lk.expiration.Milliseconds()-lk.expiration.Milliseconds()/10))
lk.checkCancel = make(chan bool)
select {
case <-checkCxt.Done():
//多次续期,直到锁被释放
isContinue := lk.done()
if !isContinue {
return
}
//取消
case <-lk.checkCancel:
fmt.Println("释放")
return
}
}
}
//判断锁是否已被释放
func (lk *lock) done() bool {
lk.mu.Lock()
defer lk.mu.Unlock()
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
res, err := redis.Client().Exists(cxt, lk.key).Result()
cancel()
if err != nil {
return false
}
if res == 1 {
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ok, err := redis.Client().Expire(cxt, lk.key, lk.expiration).Result()
cancel()
if err != nil {
return false
}
if ok {
fmt.Println("续期")
return true
}
}
return false
}
// Block 阻塞获取锁
func (lk *lock) Block(expiration time.Duration) bool {
t := time.Now()
for {
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ok, err := redis.Client().SetNX(cxt, lk.key, lk.requestId, lk.expiration).Result()
cancel()
if err != nil {
return false
}
if ok {
go lk.checkLockIsRelease()
return true
}
time.Sleep(200 * time.Millisecond)
if time.Now().Sub(t) > expiration {
return false
}
}
}
// ForceRelease 强制释放锁,忽略请求id
func (lk *lock) ForceRelease() error {
lk.mu.Lock()
defer lk.mu.Unlock()
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := redis.Client().Del(cxt, lk.key).Result()
lk.checkCancel <- true
return err
}
// Release 释放锁
func (lk *lock) Release() error {
lk.mu.Lock()
defer lk.mu.Unlock()
cxt, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
const luaScript = `
if redis.call('get', KEYS[1])==ARGV[1] then
return redis.call('del', KEYS[1])
else
return 0
end
`
script := goredis.NewScript(luaScript)
res, err := script.Run(cxt, redis.Client(), []string{lk.key}, lk.requestId).Result()
if res.(int64) != 0 {
lk.checkCancel <- true
}
return err
}