grpc 拦截器: 参数校验、日志上报、panic拦截 等等逻辑的另一种姿态

4,565 阅读7分钟

为什么需要拦截器

先把问题抛出来.

  1. 程序直接 panic 头很疼.
  2. 参数校验的代码哪里都开花, 欣赏不过来了.
  3. 一个请求不设定超时时间, 服务器堆积的请求太多了.
  4. 错误的统一处理满天飞.
  5. 日志和业务的代码到糅合在一起了.
  6. .... 上面的事情不想再做了, 🤮了.

程序 panic 了

func (g greeterImpl) OutOfIndex(ctx context.Context, request *pb.OutOfIndexRequest) (*pb.OutOfIndexResponse, error) {
	request.Ids = make([]int64, 0)
	request.Ids[1] = 1
	return &pb.OutOfIndexResponse{Data: "ok"}, nil
}

上面的这一段代码, 一旦运行整个程序就 Gg 了, 怎么处理呢? docker-compose doc 里面[好像有个, 忘记了, 真的忘记了] alway restart 可以让程序持续运行, 但是这样也有问题啊. 好像没有什么问题. 不就是让已经挂了的程序恢复么? 没有什么问题啊. 其他已经处理了一部分的义务, 但是没有处理完成的怎么办? 这个问题真的让人很头大.

那好办, 在代码里面做一个 panic 的拦截吧, 不让程序把错误抛到顶层的调用好了. 把代码改成这个样子.

func (g greeterImpl) OutOfIndex(ctx context.Context, request *pb.OutOfIndexRequest) (*pb.OutOfIndexResponse, error) {
	defer func() {
		// 拦截错误
		if err:=recover();err !=nil {
			glog.Errorf("panic:%s\n",string(debug.Stack()))
		}
	}() // 主要是增加 defer 函数的处理
	request.Ids = make([]int64, 0)
	request.Ids[1] = 1
	return &pb.OutOfIndexResponse{Data: "ok"}, nil
}

如果真的是这样写, 真的叫人头大... 万一有 10 个这样的例子, 真的要重复写 10 个这样的么? 真的叫人头大, 什么, 已经精通了 CTRL(command) + C && CTRL(command) + V 好吧. 那没有什么事情了, 也许事情本该如此完全满足需求了啊. 但是可能事情不需要搞得这么复杂, 也许并没有这么复杂....

怎么优雅拦截 panic

使用 grpc.Server 提供的拦截器, grpc 提供的拦截器有两种, 分别对应于 Request-Response模式和 流模式(一般用于大文件传输类型的请求)

type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)

type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error

怎么使用?

使用 UnaryInterceptor 来举个栗子

srv := grpc.NewServer(grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	defer func() {
		if err := recover(); err != nil {
			glog.Errorf("method:%s, time:%s, err:%v, fatal%s", info.FullMethod, time.Now().Format("20060102-15:04:05:06"), err,string(debug.Stack()))
		}
	}() // 在这里拦截下行程序抛出的异常
	// 执行对应的业务方法
	resp, err = handler(ctx, req)
	return resp, err
}))

当然了, 可以在里面加入一些其他的操作, 比如把请求日志达到 es, 等等 比如参数校验等等. 避免变量污染, 把参数检验的变量控制在了一个闭包函数里面了

srv := grpc.NewServer(grpc.UnaryInterceptor(func() grpc.UnaryServerInterceptor {
	var (
		validate = validator.New()
		uni      = ut.New(zh.New())
		trans, _ = uni.GetTranslator("zh")
	)
	err := zh_translations.RegisterDefaultTranslations(validate, trans)
	if err != nil {
		panic(err)
	}

	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		fmt.Printf("ctx %#v, req:%v, info:%#v",ctx,req,info)
		defer func() {
			if err := recover(); err != nil {
				glog.Errorf("method:%s, time:%s, err:%v, fatal%s", info.FullMethod, time.Now().Format("20060102-15:04:05:06"), err,string(debug.Stack()))
			}
		}()
	
		// 参数校验
		if err := validate.Struct(req); err != nil {
			if transErr, ok := err.(validator.ValidationErrors); ok {
				translations := transErr.Translate(trans)
				var buf bytes.Buffer
				for _, s2 := range translations {
					buf.WriteString(s2) // 这个函数写, 是有错误返回的.......
				}
				err = status.New(codes.InvalidArgument, buf.String()).Err()
				return resp, err
			}
			err = status.New(codes.Unknown, fmt.Sprintf("error%s", err)).Err()
			return resp, err
		}

		resp, err = handler(ctx, req)
		return resp, err
	}
}()))

拓展中间件的写法

如果需要处理的事情足够简单, 上面的写法已经足够用了, 但是如果没有这么简单, 比如下面的这种:

srv := grpc.NewServer(grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		fmt.Printf("ctx %#v, req:%v, info:%#v", ctx, req, info)
		defer func() {
			if err := recover(); err != nil {
				glog.Errorf("method:%s, time:%s, err:%v, fatal%s", info.FullMethod, time.Now().Format("20060102-15:04:05:06"), err, string(debug.Stack()))
			}
		}()
		// 如果没有设定超时, 设置一个超时 6 s
		if _, ok := ctx.Deadline(); !ok {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, time.Second*6)
			defer cancel()
		}
		data, err := json.Marshal(req)
		if err != nil {
			err = status.New(codes.Internal, err.Error()).Err()
			return resp, err
		}
		jData := string(data)
		glog.Errorf("method:%s, request:%v", info.FullMethod, jData)
		// 参数校验
		if err := validate.Struct(req); err != nil {
			if transErr, ok := err.(validator.ValidationErrors); ok {
				translations := transErr.Translate(trans)
				var buf bytes.Buffer
				for _, s2 := range translations {
					buf.WriteString(s2)
				}
				err = status.New(codes.InvalidArgument, buf.String()).Err()
				return resp, err
			}
			err = status.New(codes.Unknown, fmt.Sprintf("error%s", err)).Err()
			return resp, err
		}
		start := time.Now()
		resp, err = handler(ctx, req)
		glog.Infof("method:%s, request:%#v, resp:%#v, latency:%v, status:%v", info.FullMethod, req, resp, time.Now().Sub(start), status.Convert(err))
		return resp, err
	}))

其实也没有什么问题啊, 一样可以完成工作的啊.... 但是它可以完成地更加好.. 比如像下面这样....

package main

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"net"
	"os"
	"time"

	"github.com/go-playground/locales/zh"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	zh_translations "github.com/go-playground/validator/v10/translations/zh"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"mio_grpc/pb"
)

type greeterImpl struct {
}

// 调用链的处理函数
type HandlerFunc func(*Context)

// 重新包装的 Context;
type Context struct {
	req         interface{}            // 输入参数
	resp        interface{}            // 输出参数
	info        *grpc.UnaryServerInfo  // 服务的信息
	ctx         context.Context        // 服务方法的 上下文信息
	handler     grpc.UnaryHandler      // 对应服务的请求处理
	err         error                  // 错误
	reqJsData   string                 // 输入参数js格式的字符串
	respJsData  string                 // 输入参数的js格式的字符串
	handlerFunc []HandlerFunc          // 这里默认有一个处理 handler
	index       int                    // 当前回调所在的层数
	data        map[string]interface{} // 设置的 data
}

// 新建一个 context
// @param context.Context grpc 方法请求的上下文
// @param req             请求方法的输入参数
// @param resp            请求方法的输出参数
// @param info            grpc 方法名的信息, 里面的参数有 请求的方法名
// @param handler         这个参数是一个函数, 用来调用对应的 grpc 方法
//
// @ret   *Context        返回重新包装的 Context
func newContext(ctx context.Context, req, resp interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) *Context {
	data, _ := json.Marshal(req) // 序列化请求的参数
	return &Context{
		req:         req,
		resp:        resp,
		info:        info,
		err:         nil,
		ctx:         ctx,
		reqJsData:   string(data),
		respJsData:  "",
		handlerFunc: make([]HandlerFunc, 0),
		handler:     handler,
		data:        make(map[string]interface{}, 16),
	}
}

// 参数检查
func validate() HandlerFunc {
	var (
		validate = validator.New() // 校验器
		uni      = ut.New(zh.New())
		trans, _ = uni.GetTranslator("zh") // 中午翻译器
	)
	// 关联校验器和翻译器
	err := zh_translations.RegisterDefaultTranslations(validate, trans)
	if err != nil {
		panic(err)
	}
	return func(c *Context) {
		// 参数校验
		if err := validate.Struct(c.req); err != nil {
			// 判断错误是否是校验字段的错误
			// 如果是参数校验不通过, 把错误信息翻译成对应的说明
			if transErr, ok := err.(validator.ValidationErrors); ok {
				translations := transErr.Translate(trans)
				var buf bytes.Buffer
				for _, s2 := range translations {
					buf.WriteString(s2)
				}
				// grpc 返回参数错误
				err = status.New(codes.InvalidArgument, buf.String()).Err()
				// 提前终止调用
				c.AbortWith(err)
				return
			}
			// 如果校验 grpc 输入参数的时候遇到错误, 但是错误不是翻译成错误相关的, 返回未知错误
			err = status.New(codes.Unknown, fmt.Sprintf("error%s", err)).Err()
			// 提前终止调用
			c.AbortWith(err)
			return
		}
	}

}

// 获取请求参数
func (c *Context) GetReq() interface{} {
	return c.req
}

// 获取请求参数的 js 格式
func (c *Context) GetReqJsData() string {
	if c == nil {
		return ""
	}
	return c.reqJsData
}

// 设置 JsData
func (c *Context) SetReqJsData(str string) {
	if c == nil {
		return
	}
	if json.Valid([]byte(str)) {
		c.reqJsData = str
	}
}

// 设置返回请求参数的 js 格式的字符串
func (c *Context) SetRespJsData(str string) {
	if c == nil {
		return
	}
	if json.Valid([]byte(str)) {
		c.respJsData = str
	}
}

// 获取当前 grpc 需要请求的方法的名称
func (c *Context) FullMethod() string {
	if c == nil || c.info == nil {
		return ""
	}
	return c.info.FullMethod
}

func (c *Context) SetData(key string, value interface{}) {
	if c == nil {
		return
	}
	c.data[key] = value
}

func (c *Context) GetData(key string) interface{} {
	if c == nil {
		return nil
	}
	return c.data[key]
}

// 当前调用链方法所在的层数的下一层
func (c *Context) Next() {
	if c == nil {
		return
	}
	c.index++
	for (c.index) < len(c.handlerFunc) {
		c.handlerFunc[c.index](c)
		c.index++
	}

}

// 提前终止所有的调用,
func (c *Context) AbortWith(err error) {
	const (
		abortLevel = 1 << 32
	)
	c.err = err
	c.index = abortLevel
}

// 模拟日志输出到 es
func log2es() HandlerFunc {
	// 模拟输入日志到 es
	file, err := os.OpenFile("./my.txt", os.O_CREATE|os.O_RDWR|os.O_APPEND, 0766)
	if err != nil {
		panic(err)
	}
	w := bufio.NewWriter(file)
	defer w.Flush()
	return func(c *Context) {
		start := time.Now()

		c.Next() // 请求下一个方法
		_, _ = file.WriteString(
			fmt.Sprintf(
				"method:%s, status:%v, latency:%v, req:%s, resp:%s\n", c.FullMethod(), status.Convert(c.err).Code().String(), time.Now().Sub(start), c.GetReqJsData(), ""))
		//fmt.Println("log2es after ", time.Now(), writeString, err2, c.FullMethod())
	}
}

// 调用 处理函数, 这是默认调用的, 这里默认是所有的 调用都完成的时候才调用的
func procHandler(ctx *Context) {
	ctx.resp, ctx.err = ctx.handler(ctx.ctx, ctx.req)
}

// 包裹处理多个 handler func
func WrapperHandler(handFunc ...HandlerFunc) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		// 从 grpc 拦截器的回调用获取参数, 新生成一个 Context
		c := newContext(ctx, req, resp, info, handler)

		// 构造一个默认的 拦截 panic 的回调处理, 作为第一个业务处理
		c.handlerFunc = append(c.handlerFunc, func(c *Context) {
			defer func() {
				if err := recover(); err != nil {
					c.AbortWith(status.New(codes.Internal, fmt.Sprintf("errors:%v", err)).Err())
				}
			}()
			c.Next()
		})
		// 将用户输入的处理作为中间的处理
		c.handlerFunc = append(c.handlerFunc, handFunc...)
		// 用户需要处理的完成之后, 调用 真正的服务方法
		c.handlerFunc = append(c.handlerFunc, procHandler)

		// 开始调用服务
		for c.index = 0; c.index < len(c.handlerFunc); c.index++ {
			c.handlerFunc[c.index](c)
		}
		// 返回服务 resp 和 err
		return c.resp, c.err
	}
}

func (g greeterImpl) OutOfIndex(ctx context.Context, request *pb.OutOfIndexRequest) (resp *pb.OutOfIndexResponse, err error) {
	fmt.Println("OutOfIndex", request)
	time.Sleep(time.Second * 4)
	//defer func() {
	//	// 拦截错误
	//	if err := recover(); err != nil {
	//		glog.Errorf("panic:%s\n", string(debug.Stack()))
	//	}
	//}()
	resp = &pb.OutOfIndexResponse{Data: "ok"}
	request.Ids = make([]int64, 0)
	request.Ids[1] = 1
	return
}

func (g greeterImpl) NilPointer(ctx context.Context, request *pb.NilPointerRequest) (*pb.NilPointerResponse, error) {
	request.Data.Data = "work man"
	return &pb.NilPointerResponse{Data: "ok"}, nil

}

func (g greeterImpl) Hello(ctx context.Context, request *pb.HelloRequest) (*pb.HelloResponse, error) {
	//panic("implement me")
	fmt.Println("Hello:", request)
	return &pb.HelloResponse{ErrCode: "err_code"}, nil
}

func main() {
	flag.Parse()
	srv := grpc.NewServer(grpc.UnaryInterceptor(WrapperHandler(log2es(), validate())))
	listen, err := net.Listen("tcp", ":8086")
	if err != nil {
		panic(err)
	}
	pb.RegisterGreeterServer(srv, &greeterImpl{})
	if err = srv.Serve(listen); err != nil {
		panic(err)
	}

}


好像看起来复杂了, 嗯, 确实, 看起来确实是复杂了, 主要是代码行数变多了, 但是

	srv := grpc.NewServer(grpc.UnaryInterceptor(WrapperHandler(func(c *Context) {
		c.Next()
	}, log2es(), validate(), func(c *Context) {
		c.Next()
	}, func(c *Context) {
		fmt.Println("hello ")
		c.Next()
	})))

这里简单了, 而且可以处理更加多的业务逻辑了

其中用来额外的工具来替换 pb.go 的tag protoc-go-inject-tag, 这里主要用来注入了 validate

message HelloRequest {
  //@inject_tag: validate:"required,gte=0"
  int64  id = 1;
  //@inject_tag:validate:"required"
  string user_name = 2;
  //@inject_tag:validate:"required"
  string user_address = 3;
  int64  book_time = 4;
  //@inject_tag:validate:"required"
  string random_str = 5;
}

完整的代码在这里, 戳我直达 🦀🦀了