grpc服务端客户端使用中间件

237 阅读3分钟

grpc-ecosystem/go-grpc-middleware 已经提供了很多常见的中间件

  • grpc_auth
  • grpc_zap
  • grpc_recovery
  • message
  • validation
  • retries

服务端和客户端都已原生支持同时使用多个中间件, 常用中间件:

服务端

  • 记录日志
  • 记录错误
  • recover

客户端

  • 设置超时(当然也可以在单独调用接口时通过 ctx 传递超时)

中间件方法

  • 服务端
    s := grpc.NewServer(
+        grpc.ChainUnaryInterceptor(
            middleware.HelloInterceptor(), // an interceptor can come from a factory function call, or be passed directly as a function value
            middleware.HelloInterceptor2,
            // commonly used interceptors
            middleware.AccessLog,
            middleware.ErrorLog,
            middleware.Recovery,
        ),
    )
// HelloInterceptor is the factory form of a unary server interceptor: it
// returns a closure that logs one line, calls the next handler in the chain,
// logs a second line, and passes the handler's response and error through
// unchanged.
func HelloInterceptor()grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
       log.Println("你好")
       // handler invokes the rest of the chain (and ultimately the RPC method).
+       resp, err = handler(ctx, req)
       log.Println("再见")
       return resp, err
    }
}
  • 客户端
    conn, err := grpc.NewClient(*addr,   // newer API that replaces grpc.Dial
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithBlock(), // NOTE(review): grpc-go documents WithBlock as not supported by NewClient — confirm for the version in use
 +       grpc.WithChainUnaryInterceptor(
            middleware.UnaryContextTimeout(),
        ),
    )
// HelloClientInterceptor returns a unary client interceptor that prints one
// line before and one line after the outgoing call, and returns the invoker's
// error unchanged.
func HelloClientInterceptor() grpc.UnaryClientInterceptor {
    return func(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
       fmt.Println("clientinterceptor前")
       // invoker continues the chain and performs the actual RPC.
+       err := invoker(ctx, method, req, resp, cc, opts...)
       fmt.Println("clientinterceptor后")
       return err
    }
}

目录

go mod init grpc-middleware-demo
├── client
│   └── main.go
├── go.mod
├── go.sum
├── main.go
├── middleware
│   ├── client_interceptor.go
│   └── server_interceptor.go
├── pkg
│   └── errcode
│       ├── common_error.go
│       ├── errcode.go
│       └── rpc_error.go
├── proto
│   ├── helloworld.pb.go
│   ├── helloworld.proto
│   └── helloworld_grpc.pb.go
├── readme.md
├── server
│   └── main.go
└── test.json

proto/helloworld.proto

syntax = "proto3";
import "google/protobuf/any.proto";

package helloworld;

option go_package = ".;pb";

// Greeter is the greeting service used by this demo. It exposes a single
// unary RPC.
service Greeter {
  // SayHello takes a HelloRequest and returns a HelloReply.
  rpc SayHello(HelloRequest) returns (HelloReply) {}
}

// HelloRequest is the SayHello request message.
message HelloRequest {
  // name is the user's name to greet.
  string name = 1;
}

// HelloReply is the SayHello response message.
message HelloReply {
  // message is the greeting text returned to the caller.
  string message = 1;
}

// Error is the structured error payload that pkg/errcode attaches to a gRPC
// status as a detail.
message Error {
  // code is the application-level error code (see pkg/errcode).
  int32 code = 1;
  // message is the human-readable text for the error code.
  string message = 2;
  // detail can carry an arbitrary extra payload; the Go helpers shown in
  // pkg/errcode/rpc_error.go leave it unset.
  google.protobuf.Any detail = 3;
}


//protoc --go_out=./proto --go-grpc_out=./proto ./proto/*.proto

server/main.go

package main

import (
        "context"
        "flag"
        "fmt"
        "grpc-middleware-demo/middleware"
        pb "grpc-middleware-demo/proto"
        "log"
        "net"

        "google.golang.org/grpc"
)

var (
        // port is the TCP port the server listens on, set with the -port flag
        // (default 50051).
        port = flag.Int("port", 50051, "The server port")
)

// server implements helloworld.GreeterServer. It embeds
// pb.UnimplementedGreeterServer and overrides SayHello.
type server struct {
        pb.UnimplementedGreeterServer
}

// SayHello implements helloworld.GreeterServer.
//
// As written the handler panics on its first statement, so the logging and
// reply construction below it are unreachable.
// NOTE(review): the panic presumably exists to exercise middleware.Recovery in
// this demo — confirm, and delete it to get a real greeting back.
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
        panic("implement me")
        log.Printf("Received: %v", in.GetName())
        return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil
}

// main parses the command-line flags, opens a TCP listener on localhost at the
// configured port, builds a gRPC server with the demo interceptor chain,
// registers the Greeter service and serves until Serve returns an error.
func main() {
        flag.Parse()

        listenAddr := fmt.Sprintf("localhost:%d", *port)
        lis, err := net.Listen("tcp", listenAddr)
        if err != nil {
                log.Fatalf("failed to listen: %v", err)
        }

        // grpc.ChainUnaryInterceptor runs interceptors in the order given; per
        // the grpc-go documentation the first one is the outermost.
        chain := grpc.ChainUnaryInterceptor(
                middleware.HelloInterceptor,
                middleware.HelloInterceptor2,
                // commonly used interceptors
                middleware.AccessLog,
                middleware.ErrorLog,
                middleware.Recovery,
        )
        grpcServer := grpc.NewServer(chain)
        pb.RegisterGreeterServer(grpcServer, &server{})

        log.Printf("server listening at %v", lis.Addr())
        err = grpcServer.Serve(lis)
        if err != nil {
                log.Fatalf("failed to serve: %v", err)
        }
}

middleware/server_interceptor.go

package middleware

import (
        "context"
        "google.golang.org/grpc"
        "grpc-middleware-demo/pkg/errcode"
        "log"
        "runtime/debug"
        "time"
)

// HelloInterceptor is a demo unary server interceptor. It logs a greeting,
// runs the rest of the chain through handler, logs a farewell, and returns
// the handler's response and error untouched.
func HelloInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        log.Println("你好")

        reply, callErr := handler(ctx, req)

        log.Println("再见")
        return reply, callErr
}

// HelloInterceptor2 is a second demo unary server interceptor, identical in
// shape to HelloInterceptor but with its own log lines, so the nesting order
// of chained interceptors is visible in the output.
func HelloInterceptor2(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        log.Println("你好2")

        reply, callErr := handler(ctx, req)

        log.Println("再见2")
        return reply, callErr
}

// AccessLog is a unary server interceptor that writes one log line before the
// call (method, start time, request) and one after it (method, start time,
// end time, response). Times are Unix seconds. The handler's response and
// error are returned unchanged.
func AccessLog(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        // Unix() is independent of the time's location, so no zone conversion
        // is needed before taking it.
        startedAt := time.Now().Unix()
        log.Printf("access request log: method: %s, begin_time: %d, request: %v", info.FullMethod, startedAt, req)

        reply, callErr := handler(ctx, req)

        finishedAt := time.Now().Unix()
        log.Printf("access response log: method: %s, begin_time: %d, end_time: %d, response: %v", info.FullMethod, startedAt, finishedAt, reply)
        return reply, callErr
}

// ErrorLog is a unary server interceptor that logs failed calls. When the
// handler returns a non-nil error it is converted with errcode.FromError and
// the method, status code, error text and details are logged. The handler's
// response and error are always returned unchanged.
func ErrorLog(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        reply, callErr := handler(ctx, req)
        if callErr == nil {
                // Successful call: nothing to log.
                return reply, nil
        }

        st := errcode.FromError(callErr)
        log.Printf("error log: method: %s, code: %v, message: %v, details: %v", info.FullMethod, st.Code(), st.Err().Error(), st.Details())
        return reply, callErr
}

// Recovery is a unary server interceptor that converts a panic raised further
// down the chain (including the RPC method itself) into an ordinary error.
//
// On panic it logs the method, the panic value and the stack trace, then
// returns a nil response together with the gRPC form of errcode.Fail. The
// results are named so the deferred function can set them: without that, a
// recovered call would return (nil, nil) and the panic would be reported to
// the rest of the chain as a success with a nil response. The panic value is
// only logged, never sent to the client.
func Recovery(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
        defer func() {
                if e := recover(); e != nil {
                        recoveryLog := "recovery log: method: %s, message: %v, stack: %s"
                        log.Printf(recoveryLog, info.FullMethod, e, debug.Stack())
                        // Discard any partial response and surface an internal error.
                        resp, err = nil, errcode.TogRPCError(errcode.Fail)
                }
        }()

        return handler(ctx, req)
}

client/main.go

package main

import (
        "context"
        "flag"
        "grpc-middleware-demo/middleware"
        pb "grpc-middleware-demo/proto"
        "log"
        "time"

        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"
)

const (
        // defaultName is the value of the -name flag when it is not given.
        defaultName = "world"
)

var (
        // addr is the server address to connect to, set with -addr.
        addr = flag.String("addr", "localhost:50051", "the address to connect to")
        // name is the name sent in the greeting request, set with -name.
        name = flag.String("name", defaultName, "Name to greet")
)

// main parses the flags, creates a client connection with the timeout
// interceptor installed, performs one SayHello call and logs the greeting.
// Any failure terminates the process through log.Fatalf.
func main() {
        flag.Parse()

        // Create the client for the server address. grpc.WithBlock is not
        // passed: grpc-go documents it as unsupported by grpc.NewClient, which
        // performs no I/O and connects lazily on the first RPC.
        conn, err := grpc.NewClient(*addr,
                grpc.WithTransportCredentials(insecure.NewCredentials()),
                grpc.WithChainUnaryInterceptor(
                        middleware.UnaryContextTimeout(),
                ))
        if err != nil {
                log.Fatalf("did not connect: %v", err)
        }
        defer conn.Close()
        c := pb.NewGreeterClient(conn)

        // Contact the server and print out its response. This context already
        // has a one-second deadline, so UnaryContextTimeout leaves it as is.
        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()
        r, err := c.SayHello(ctx, &pb.HelloRequest{Name: *name})
        if err != nil {
                log.Fatalf("could not greet: %v", err)
        }
        log.Printf("Greeting: %s", r.GetMessage())
}

middleware/client_interceptor.go

package middleware

import (
        "context"
        "google.golang.org/grpc"
        "time"
)

// defaultContextTimeout makes sure ctx carries a deadline.
//
// If ctx already has one it is returned unchanged together with a nil
// CancelFunc. Otherwise a child context with a 3-second timeout is returned
// along with its CancelFunc. Callers must check the CancelFunc for nil before
// calling it.
func defaultContextTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
        // A deadline chosen by the caller always wins.
        if _, hasDeadline := ctx.Deadline(); hasDeadline {
                return ctx, nil
        }

        const fallback = 3 * time.Second
        return context.WithTimeout(ctx, fallback)
}

// UnaryContextTimeout returns a unary client interceptor that applies the
// default timeout from defaultContextTimeout to calls whose context has no
// deadline. A context that already has a deadline is passed on unchanged. Any
// timeout context created here is cancelled once the call returns.
func UnaryContextTimeout() grpc.UnaryClientInterceptor {
        interceptor := func(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
                callCtx, release := defaultContextTimeout(ctx)
                if release == nil {
                        // The caller's own deadline is in effect; nothing to release.
                        return invoker(callCtx, method, req, resp, cc, opts...)
                }
                defer release()

                return invoker(callCtx, method, req, resp, cc, opts...)
        }
        return interceptor
}

// timeoutClientStream wraps a grpc.ClientStream whose context was given a
// default timeout by StreamContextTimeout, and holds the CancelFunc for it.
type timeoutClientStream struct {
        grpc.ClientStream
        cancel context.CancelFunc
}

// RecvMsg forwards to the wrapped stream. Once the wrapped stream reports any
// error (including io.EOF at normal end of stream) it calls cancel so the
// timeout context's resources are released.
func (s *timeoutClientStream) RecvMsg(m interface{}) error {
        err := s.ClientStream.RecvMsg(m)
        if err != nil {
                s.cancel()
        }
        return err
}

// StreamContextTimeout returns a stream client interceptor that applies the
// default timeout from defaultContextTimeout to streams whose context has no
// deadline. A context that already has a deadline is passed on unchanged.
//
// The timeout context must outlive this interceptor call, because the stream
// keeps using it after the streamer returns. Cancelling on return would abort
// every stream that received the default timeout. Cancel is therefore called
// only when stream creation fails or when RecvMsg on the returned stream
// reports an error. If the caller abandons the stream without draining it,
// the context is still released when the timeout expires.
func StreamContextTimeout() grpc.StreamClientInterceptor {
        return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
                ctx, cancel := defaultContextTimeout(ctx)

                stream, err := streamer(ctx, desc, cc, method, opts...)
                if cancel == nil {
                        // The caller's own deadline is in effect; nothing to manage here.
                        return stream, err
                }
                if err != nil {
                        cancel()
                        return stream, err
                }

                return &timeoutClientStream{ClientStream: stream, cancel: cancel}, nil
        }
}

错误码

pkg/errcode/common_error.go
package errcode

// Common application error codes. Each value is created with NewError, which
// panics if the same numeric code is registered twice. ToRPCCode translates
// these codes into gRPC status codes.
var (
        Success          = NewError(0, "成功")               // success
        Fail             = NewError(10000000, "内部错误")      // internal error
        InvalidParams    = NewError(10000001, "无效参数")      // invalid parameters
        Unauthorized     = NewError(10000002, "认证错误")      // authentication failed
        NotFound         = NewError(10000003, "没有找到")      // not found
        Unknown          = NewError(10000004, "未知")        // unknown
        DeadlineExceeded = NewError(10000005, "超出最后截止期限") // deadline exceeded
        AccessDenied     = NewError(10000006, "访问被拒绝")     // access denied
        LimitExceed      = NewError(10000007, "访问限制")      // rate/limit exceeded
        MethodNotAllowed = NewError(10000008, "不支持该方法")    // method not supported
)
pkg/errcode/errcode.go
package errcode

import "fmt"

// Error is an application error made of a numeric code and a message.
type Error struct {
        code int
        msg  string
}

// _codes records every code handed to NewError together with its message, so
// that a code can be registered only once.
var _codes = map[int]string{}

// NewError registers code with msg and returns the matching *Error. It panics
// if code has already been registered.
func NewError(code int, msg string) *Error {
        if _, taken := _codes[code]; taken {
                panic(fmt.Sprintf("错误码 %d 已经存在,请更换一个", code))
        }
        _codes[code] = msg

        e := new(Error)
        e.code, e.msg = code, msg
        return e
}

// Error implements the error interface, formatting both the code and the
// message.
func (e *Error) Error() string {
        return fmt.Sprintf("错误码:%d, 错误信息:%s", e.code, e.msg)
}

// Code returns the numeric error code.
func (e *Error) Code() int { return e.code }

// Msg returns the error message.
func (e *Error) Msg() string { return e.msg }
pkg/errcode/rpc_error.go
package errcode

import (
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/status"
        pb "grpc-middleware-demo/proto"
)

// Status wraps a gRPC *status.Status by embedding it, so all of its methods
// are available directly on Status.
type Status struct {
        *status.Status
}

// FromError converts err into a *Status using status.FromError. The boolean
// that status.FromError also returns is intentionally discarded; callers only
// get the resulting status.
func FromError(err error) *Status {
        st, _ := status.FromError(err)
        wrapped := Status{Status: st}
        return &wrapped
}

// TogRPCError converts an application *Error into a gRPC error. The status
// code comes from ToRPCCode, the message is err.Msg(), and a pb.Error carrying
// the original code and message is attached as a status detail.
//
// If attaching the detail fails, the plain status without details is returned
// instead. WithDetails yields no usable status on failure, and ignoring that
// (as the code did before) could turn the failure into a nil error, i.e.
// report success.
func TogRPCError(err *Error) error {
        base := status.New(ToRPCCode(err.Code()), err.Msg())

        detailed, detailErr := base.WithDetails(&pb.Error{Code: int32(err.Code()), Message: err.Msg()})
        if detailErr != nil {
                return base.Err()
        }
        return detailed.Err()
}

// ToRPCStatus builds a *Status from an application code and message. The gRPC
// code comes from ToRPCCode, and a pb.Error carrying the original code and
// message is attached as a status detail.
//
// If attaching the detail fails, the status without details is returned, so
// the result never wraps the unusable status that WithDetails yields on
// failure.
func ToRPCStatus(code int, msg string) *Status {
        base := status.New(ToRPCCode(code), msg)

        detailed, detailErr := base.WithDetails(&pb.Error{Code: int32(code), Message: msg})
        if detailErr != nil {
                return &Status{base}
        }
        return &Status{detailed}
}

// ToRPCCode maps an application error code to the corresponding gRPC status
// code. Any code without an explicit mapping, including Success and Unknown,
// maps to codes.Unknown.
func ToRPCCode(code int) codes.Code {
        switch code {
        case Fail.Code():
                return codes.Internal
        case InvalidParams.Code():
                return codes.InvalidArgument
        case Unauthorized.Code():
                return codes.Unauthenticated
        case AccessDenied.Code():
                return codes.PermissionDenied
        case DeadlineExceeded.Code():
                return codes.DeadlineExceeded
        case NotFound.Code():
                return codes.NotFound
        case LimitExceed.Code():
                return codes.ResourceExhausted
        case MethodNotAllowed.Code():
                return codes.Unimplemented
        }

        return codes.Unknown
}

测试

image.png