golang context.Context + etcd 实现简单的实时任务启停

991 阅读2分钟

工作中用go写了很多服务,也抽象出了一些简单的实时启停编程模型。

主要利用到golang标准库中的context.Context,这绝对是一个好东西,用来控制goroutine的生命周期,配合起来很方便使用,兼具可观察性(方便debug)。

用到etcd主要也是公司的中间件直接就提供了它,go实现很轻量,并且golang-client api也很友好,直接支持watch操作,可以用来通知程序任务的关闭和开启。注意用的时候需要根据场景配置好持久化。(程序初次启动可以通过它获取要运行的任务)

以下代码以实现一个从kafka中消费消息,定时或者定量上传到s3的程序为例。同步的任务都可以通过网页动态开启和关闭。

稍微讲述一下这里的套路,main.go中,先开辟一个 rootCtx

package main

import (
	"context"
	"log"
	"os"
	"os/signal"
	"syscall"
)

var (
	rootCtx context.Context
)

func init() {
	rootCtx = context.Background()
}

func main() {
	ctx, cancel := context.WithCancel(rootCtx)

	go task.InitAndWatch(ctx) // 这里写拉起任务的具体代码

	termChan := make(chan os.Signal, 1)
	signal.Notify(termChan, syscall.SIGINT, syscall.SIGTERM)

	log.Println("exiting:", <-termChan)
	cancel()
	log.Println("exited")
}

记住这个程序后面所有的goroutine,都要从rootCtx派生出去。

context.WithCancel是我们用到的核心api,如果cancel了父级ctx,那么基于这个ctx派生出去的所有子ctx,都会接收到关闭信号。

这里会有一些捕获关闭程序信号的监听代码,不管你的代码是容器部署还是传统部署,都建议这么做,尽量实现优雅关闭。

func InitAndWatch(ctx context.Context) {
	var err error
	etcdClient, err = clientv3.New(clientv3.Config{
		Endpoints:   conf.Etcd.Urls,
		DialTimeout: 5 * time.Second,
		DialOptions: []grpc.DialOption{grpc.WithBlock()}, // 需要注意带上withBlock,配合超时,保证直接连接上etcd。
	})

	if err != nil {
		panic(fmt.Errorf("无法连接etcd: %w", err))
	}

	// 这里我的任务列表是通过etcd持久化了,所以我直接透过etcd获取所有的任务
    // etcd存储的其实是类似一个文件夹目录的结构
    // 这里的 conf.Etcd.KeyPrefix 是我们的任务目录前缀 比如 “/s3-sync/job/”
	rangeResp, err := etcdClient.Get(ctx, conf.Etcd.KeyPrefix, clientv3.WithPrefix(), clientv3.WithKeysOnly())
	if err != nil {
		panic(fmt.Errorf("首次初始化s3同步任务失败: %w", err))
	}

	for _, kv := range rangeResp.Kvs {
		jobID, err := parseIDFromKey(string(kv.Key)) // 这里实现你自己的逻辑,解析出任务id
		if err != nil {
			continue
		}

		err = startAndStore(ctx, jobID) // 这里传入ctx,开启任务的goroutine
		if err != nil {
			logger.Error("can't startAndStore job", zap.Error(err), zap.Int64("jobID", jobID))
		}
	}

	currentRevision := rangeResp.Header.Revision + 1 // 我们观察etcd的下一个revision周期

	watcher := clientv3.NewWatcher(etcdClient)
	watchChan := watcher.Watch(ctx, conf.Etcd.KeyPrefix, clientv3.WithPrefix(), clientv3.WithKeysOnly(), clientv3.WithRev(currentRevision))

	for watchResp := range watchChan {
		for _, ev := range watchResp.Events {
			jobID, err := parseIDFromKey(string(ev.Kv.Key))
			if err != nil {
				continue
			}

			switch ev.Type {
			case mvccpb.PUT:
				logger.Info("观察到PUT事件", zap.Int64("jobID", jobID))
				// 需要检查jobID是否已存在
				if oldJob, ok := Tasks.Load(jobID); ok {
					oldJob.stopAndDelete()
				}

				err = startAndStore(ctx, jobID)
				if err != nil {
					logger.Error("启动任务出错", zap.Error(err))
				}
			case mvccpb.DELETE:
				logger.Info("观察到DELETE事件", zap.Int64("jobID", jobID))
				// 需要检查jobID是否已存在
				if oldJob, ok := Tasks.Load(jobID); ok {
					oldJob.stopAndDelete()
				}
			}
		}
	}

}

这一段代码稍微有些长,它有以下几个职责:

  • 初次启动从etcd获取所有任务,并通过for循环启动
  • 启动完现有的所有任务之后,开始观察etcd中我们指定的目录前缀,发现PUT事件则开启新任务,发现DELETE则检查关闭老的任务。
  • 所有新启动的任务,都带上了main函数中传入的 可关闭的 ctx。

读到这里不难发现,如果我们杀掉这个程序(Ctrl+C 或者 发送kill信号),所有任务的ctx将会收到cancel信号。

一些例子代码出用到的数据结构:

Job结构体

type Job struct {
	ctx    context.Context    // 核心
	cancel context.CancelFunc // 核心

	mu sync.RWMutex

	meta      *db.JobConfig
	startedAt time.Time

	logger *zap.Logger

	consumer   *kafka.Consumer // 这里和你的具体任务相关
	partitions mapset.Set

	tmp []byte

	offsets map[int32]kafka.TopicPartition

	svc      *s3.S3
	uploader *s3manager.Uploader
}

存储当前所有任务的一个map

var Tasks sync.Map // key为 jobID

下面再分析以下具体启动单个任务的代码:

func startAndStore(ctx context.Context, jobID int64) error {
	jobConfig, err := db.GetSyncJobInfo(jobID) // 从你的数据源获取任务的具体配置信息
	if err != nil {
		return fmt.Errorf("query database error: %w", err)
	}

	logger := logger.With(
		zap.Int64("jobID", jobID),
		zap.String("topic", jobConfig.Topic),
		zap.String("auto.offset.reset", jobConfig.AutoOffsetReset),
		zap.String("consumer group", jobConfig.ConsumerGroup),
		zap.Strings("brokers", jobConfig.Brokers),
	)

	logger.Info("query database success", zap.Any("config", jobConfig))

	ctx, cancel := context.WithCancel(ctx) // 这里需要仔细看,我们基于传入的ctx又派生了一个可以cancel的ctx。
    
	task := &Job{
		ctx:    ctx,
		cancel: cancel,

		meta: jobConfig,

		mu: sync.RWMutex{},

		startedAt: time.Now().Local(),

		logger:     logger,
		partitions: mapset.NewSet(),
	}

	task.svc, task.uploader = sink.NewUploader()

	consumer, err := task.NewConsumerForTopic()
	if err != nil {
		logger.Error("无法建立kafka客户端", zap.Error(err))
		return err
	}

	task.consumer = consumer

	Tasks.Store(jobID, task) // 为了方便debug将开启的任务存入一个sync.Map中,这样可以通过遍历这个Map知道当前有哪些任务在跑。

	// 开始工作携程
	task.start()

	logger.Info("成功开启并保存到Map")

	return nil
}

这里是关闭单个任务的代码:

func (job *Job) stopAndDelete() {
	job.cancel() // 核心
	Tasks.Delete(job.meta.ID)
	job.logger.Info("成功关闭并从Map删除")
}

再接下来看一下 task.start(),在这个函数中,我们将通过 for select case的组合,来达到工作的同时,监听关闭信号。

func (job *Job) start() {
	go func() {

	ResetLoop:
		for {
			// 清理历史缓存
			job.tmp = job.tmp[:0]
			job.offsets = make(map[int32]kafka.TopicPartition)

			job.logger.Info("reset context.withTimeout")
			ctx, cancel := context.WithTimeout(job.ctx, job.ValidS3ObjectRollInterval()) // 在我的任务中,这里会更加复杂,要求每过固定的一个间隔,就上传一次。

		PollLoop:
			select {
			case <-job.ctx.Done():
				job.logger.Info("收到网页上的关闭任务信号")
				// 上传内存里的内容并且提交offset

				err := job.consumer.Close()
				if err != nil {
					job.logger.Error("关闭kafka客户端出错", zap.Error(err))
				}
				cancel()
				job.logger.Info("成功关闭任务")
				job.tmp = nil
				return

			case <-ctx.Done():
				// 上传并且提交offsets
				job.mu.Lock()

				job.logger.Info("需要上传")

				if len(job.tmp) > 0 {
					// 开始上传
					now := time.Now().Local()

					bucket := aws.String(job.meta.S3BucketName)
					key := aws.String(fmt.Sprintf("%s/%s/%s/%s-%s_%s.%d.%s",
						job.meta.S3FolderName,
						job.meta.Topic,
						now.Format("20060102"),
						job.meta.S3ObjectPrefix,
						now.Format("15-04"),
						strings.ReplaceAll(conf.Server.PodIP, ".", "-"),
						now.UnixNano()/int64(time.Millisecond),
						job.meta.S3ObjectSuffix),
					)

					res, err := job.uploader.UploadWithContext(job.ctx, &s3manager.UploadInput{
						Body:   bytes.NewReader(job.tmp),
						Bucket: bucket,
						Key:    key,
					}, func(uploader *s3manager.Uploader) {
						uploader.PartSize = int64(job.ValidS3ObjectRollSize() * 1024 * 1024 / 5)
						uploader.Concurrency = 5
					})
					if err != nil {
						job.logger.Error("上传出错", zap.Error(err))
						monitor.S3UploadErrorCounter.WithLabelValues(job.meta.Topic).Inc()

						job.mu.Unlock()
						goto ResetLoop
					}
					if res != nil {
						monitor.S3UploadSuccessCounter.WithLabelValues(job.meta.Topic).Inc()
						job.logger.Info("上传成功", zap.String("url", res.Location))
					}

					offsets := make([]kafka.TopicPartition, 0)
					for k, v := range job.offsets {
						if !job.partitions.Contains(k) {
							continue
						}
						offsets = append(offsets, kafka.TopicPartition{
							Topic:     &job.meta.Topic,
							Partition: k,
							Offset:    v.Offset + 1, // 需要 +1
						})
					}

					committed, err := job.consumer.CommitOffsets(offsets)
					if err != nil {
						job.logger.Error("提交offsets出错", zap.Error(err), zap.Any("offset", offsets))
						monitor.CommitErrorCounter.WithLabelValues(job.meta.Topic).Inc()

						// TODO: 删除已经上传的文件 没有经过测试 不知道有没有删除权限
						_, err = job.svc.DeleteObject(&s3.DeleteObjectInput{
							Bucket: bucket,
							Key:    key,
						})
						if err != nil {
							job.logger.Error("删除文件出错", zap.Error(err), zap.String("bucket", *bucket), zap.String("key", *key))
						}

						job.mu.Unlock()
						goto ResetLoop
					}

					monitor.CommitSuccessCounter.WithLabelValues(job.meta.Topic).Inc()
					job.logger.Info("提交offsets成功", zap.Any("offsets", committed))
				} else {
					job.logger.Info("消息量为0,跳过上传和提交offsets")
				}

				// 重要:解锁
				job.mu.Unlock()

				job.logger.Info("重设context,继续主循环")

				goto ResetLoop

			default:
				//workerLogger.Info("startAndStore polling...")
				ev := job.consumer.Poll(100)
				if ev == nil {
					//logger.Info("nil event")
					goto PollLoop
				}
				switch e := ev.(type) {
				case *kafka.Message:
					if len(e.Value) == 0 {
						job.logger.Info("消息为空,跳过", zap.String("msg", e.String()))
						break
					}
					monitor.PollSuccessCounter.WithLabelValues(job.meta.Topic, strconv.FormatInt(int64(e.TopicPartition.Partition), 10)).Inc()

					if e.TopicPartition.Error != nil {
						job.logger.Error("TopicPartition.Error", zap.Error(e.TopicPartition.Error))
					}

					job.offsets[e.TopicPartition.Partition] = e.TopicPartition

					if len(job.tmp) > 0 {
						// 前面有消息就增加换行符
						job.tmp = append(job.tmp, []byte("\n")...)
					}

					job.tmp = append(job.tmp, e.Value...)

					size := len(job.tmp)

					if size >= job.ValidS3ObjectRollSize()*1024*1024 {
						job.logger.Info("达到约定大小,需要上传,内存里msg大小",
							zap.Int("size", size),
							zap.Int("validS3ObjectRollSize", job.ValidS3ObjectRollSize()),
						)
						cancel()
					}

				case kafka.Error:
					job.logger.Error("消费出错", zap.String("error", e.Error()))
					monitor.PollErrorCounter.WithLabelValues(job.meta.Topic).Inc()

				default:
					job.logger.Warn("未处理的事件", zap.Any("event", e))
				}

				goto PollLoop
			}
		}

	}()
}

还有一个帮助函数用于定时dump工作任务,方便debug。

func PollState() {
	c := cron.New(cron.WithChain(
		cron.Recover(cron.DefaultLogger),
	))
	// 每分钟末执行
	_, err := c.AddFunc("* * * * *", func() {
		// startAndStore to log state
		logger.Info("开始打印所有jobs")

		Tasks.Range(func(jobID int64, job *Job) bool {
			logger.Info("job state", zap.Int64("jobID", jobID), zap.String("partitions", job.partitions.String()))
			return true
		})

		logger.Info("结束打印所有jobs")
	})

	if err != nil {
		panic(err)
	}

	c.Start()
}