grpc-ecosystem/go-grpc-middleware 已经提供了很多常见的中间件:
- grpc_auth
- grpc_zap
- grpc_recovery
- message
- validation
- retries
服务端和客户端都已原生支持同时使用多个中间件, 常用中间件:
服务端
- 记录日志
- 记录错误
- recover
客户端
- 设置超时(当然也可以在单独调用接口时通过 ctx 传递超时时间)
中间件方法
- 服务端
s := grpc.NewServer(
+ grpc.ChainUnaryInterceptor(
middleware.HelloInterceptor(), //中间件可以写成函数形式, 也可以直接返回
middleware.HelloInterceptor2,
//常用的中间件
middleware.AccessLog,
middleware.ErrorLog,
middleware.Recovery,
),
)
func HelloInterceptor()grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
log.Println("你好")
+ resp, err = handler(ctx, req)
log.Println("再见")
return resp, err
}
}
- 客户端
conn, err := grpc.NewClient(*addr, //新写法取代grpc.Dial
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
+ grpc.WithChainUnaryInterceptor(
middleware.UnaryContextTimeout(),
),
)
func HelloClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
fmt.Println("clientinterceptor前")
+ err := invoker(ctx, method, req, resp, cc, opts...)
fmt.Println("clientinterceptor后")
return err
}
}
目录
go mod init grpc-middleware-demo
├── client
│ └── main.go
├── go.mod
├── go.sum
├── main.go
├── middleware
│ ├── client_interceptor.go
│ └── server_interceptor.go
├── pkg
│ └── errcode
│ ├── common_error.go
│ ├── errcode.go
│ └── rpc_error.go
├── proto
│ ├── helloworld.pb.go
│ ├── helloworld.proto
│ └── helloworld_grpc.pb.go
├── readme.md
├── server
│ └── main.go
└── test.json
proto/helloworld.proto
syntax = "proto3";
import "google/protobuf/any.proto";
package helloworld;
option go_package = ".;pb";
// Greeter is the greeting service definition. It exposes a single unary RPC.
service Greeter {
// SayHello sends a greeting: it takes a HelloRequest and returns a HelloReply.
rpc SayHello(HelloRequest) returns (HelloReply) {}
}
// HelloRequest is the request message containing the user's name.
message HelloRequest {
// name is the name to greet.
string name = 1;
}
// HelloReply is the response message containing the greeting.
message HelloReply {
// message is the greeting text.
string message = 1;
}
// Error is a structured error payload: a numeric code, a human-readable
// message, and an optional detail of arbitrary message type.
message Error {
// code is the numeric error code.
int32 code = 1;
// message is the error text.
string message = 2;
// detail can carry any additional protobuf message.
google.protobuf.Any detail = 3;
}
//protoc --go_out=./proto --go-grpc_out=./proto ./proto/*.proto
server/main.go
package main
import (
"context"
"flag"
"fmt"
"grpc-middleware-demo/middleware"
pb "grpc-middleware-demo/proto"
"log"
"net"
"google.golang.org/grpc"
)
var (
// port is the TCP port the server listens on, set with -port (default 50051).
port = flag.Int("port", 50051, "The server port")
)
// server is used to implement helloworld.GreeterServer.
// It embeds pb.UnimplementedGreeterServer and overrides SayHello.
type server struct {
pb.UnimplementedGreeterServer
}
// SayHello implements helloworld.GreeterServer.
//
// NOTE(review): the unconditional panic below makes the log and return
// statements unreachable. It looks deliberate, so that the Recovery
// interceptor registered in main has a panic to catch — confirm before
// removing it.
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
panic("implement me")
log.Printf("Received: %v", in.GetName())
return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil
}
// main parses the flags, opens the TCP listener on localhost, and serves the
// Greeter service behind the chained unary interceptors until Serve fails.
func main() {
	flag.Parse()

	listenAddr := fmt.Sprintf("localhost:%d", *port)
	lis, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}

	// Interceptors are chained in the order listed; the first is the outermost.
	chain := []grpc.UnaryServerInterceptor{
		middleware.HelloInterceptor,
		middleware.HelloInterceptor2,
		// Commonly used middleware.
		middleware.AccessLog,
		middleware.ErrorLog,
		middleware.Recovery,
	}
	srv := grpc.NewServer(grpc.ChainUnaryInterceptor(chain...))
	pb.RegisterGreeterServer(srv, &server{})

	log.Printf("server listening at %v", lis.Addr())
	if serveErr := srv.Serve(lis); serveErr != nil {
		log.Fatalf("failed to serve: %v", serveErr)
	}
}
middleware/server_interceptor.go
package middleware
import (
"context"
"google.golang.org/grpc"
"grpc-middleware-demo/pkg/errcode"
"log"
"runtime/debug"
"time"
)
// HelloInterceptor is a demo unary server interceptor. It logs a greeting
// before invoking the handler and a farewell after the handler returns, and
// passes the handler's response and error through unchanged.
func HelloInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	log.Println("你好")

	reply, callErr := handler(ctx, req)

	log.Println("再见")
	return reply, callErr
}
// HelloInterceptor2 is a second demo unary server interceptor. It behaves
// like HelloInterceptor but logs distinct messages, which makes the nesting
// order of chained interceptors visible in the log output.
func HelloInterceptor2(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	log.Println("你好2")

	reply, callErr := handler(ctx, req)

	log.Println("再见2")
	return reply, callErr
}
// AccessLog is a unary server interceptor that writes one log line before the
// handler runs (method, begin time, request) and one after it returns
// (method, begin time, end time, response). Times are Unix seconds. The
// handler's response and error are returned unchanged.
func AccessLog(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	startedAt := time.Now().Unix()
	log.Printf("access request log: method: %s, begin_time: %d, request: %v", info.FullMethod, startedAt, req)

	reply, callErr := handler(ctx, req)

	finishedAt := time.Now().Unix()
	log.Printf("access response log: method: %s, begin_time: %d, end_time: %d, response: %v", info.FullMethod, startedAt, finishedAt, reply)
	return reply, callErr
}
// ErrorLog is a unary server interceptor that logs the method, status code,
// message and details of any error returned by the handler. Successful calls
// are not logged. The handler's response and error are returned unchanged.
func ErrorLog(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	reply, callErr := handler(ctx, req)
	if callErr == nil {
		return reply, callErr
	}

	// Convert the error to a status so code and details can be logged.
	st := errcode.FromError(callErr)
	log.Printf("error log: method: %s, code: %v, message: %v, details: %v", info.FullMethod, st.Code(), st.Err().Error(), st.Details())
	return reply, callErr
}
// Recovery is a unary server interceptor that recovers from a panic raised by
// the handler (or by interceptors nested inside it), logs the panic value and
// stack trace, and reports the call as failed.
//
// When a panic is recovered the response is nil and the error is the gRPC
// form of errcode.Fail. Otherwise the handler's response and error are
// returned unchanged.
func Recovery(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	defer func() {
		if e := recover(); e != nil {
			log.Printf("recovery log: method: %s, message: %v, stack: %s", info.FullMethod, e, debug.Stack())
			// Named results let the deferred function replace the return
			// values. Without this the call would complete with (nil, nil)
			// and look successful to the outer interceptors.
			resp, err = nil, errcode.TogRPCError(errcode.Fail)
		}
	}()
	return handler(ctx, req)
}
client/main.go
package main
import (
"context"
"flag"
"grpc-middleware-demo/middleware"
pb "grpc-middleware-demo/proto"
"log"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
const (
// defaultName is the name greeted when -name is not given.
defaultName = "world"
)
var (
// addr is the server address to connect to, set with -addr.
addr = flag.String("addr", "localhost:50051", "the address to connect to")
// name is the name sent in the HelloRequest, set with -name.
name = flag.String("name", defaultName, "Name to greet")
)
// main creates a client for the Greeter server, sends one SayHello request
// with a one-second deadline, and logs the reply. Any failure terminates the
// process via log.Fatalf.
func main() {
	flag.Parse()
	// Set up a connection to the server.
	// grpc.WithBlock is deliberately not passed: grpc.NewClient ignores it,
	// and the connection is established lazily when the first RPC is made.
	conn, err := grpc.NewClient(*addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithChainUnaryInterceptor(
			middleware.UnaryContextTimeout(),
		))
	if err != nil {
		log.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	c := pb.NewGreeterClient(conn)

	// Contact the server and print out its response.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	r, err := c.SayHello(ctx, &pb.HelloRequest{Name: *name})
	if err != nil {
		log.Fatalf("could not greet: %v", err)
	}
	log.Printf("Greeting: %s", r.GetMessage())
}
middleware/client_interceptor.go
package middleware
import (
"context"
"google.golang.org/grpc"
"time"
)
// defaultContextTimeout makes sure ctx carries a deadline.
//
// If ctx already has a deadline it is returned as is together with a nil
// CancelFunc, so callers must nil-check before calling it. Otherwise a child
// context with a three-second timeout is returned along with its CancelFunc.
func defaultContextTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	const fallback = 3 * time.Second

	if _, hasDeadline := ctx.Deadline(); hasDeadline {
		return ctx, nil
	}
	return context.WithTimeout(ctx, fallback)
}
// UnaryContextTimeout returns a unary client interceptor that applies the
// default timeout to calls whose context has no deadline. A deadline set by
// the caller is left untouched.
func UnaryContextTimeout() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		callCtx, release := defaultContextTimeout(ctx)
		if release != nil {
			// A unary call is finished when invoker returns, so the timeout
			// context can be released on the way out.
			defer release()
		}
		return invoker(callCtx, method, req, resp, cc, opts...)
	}
}
// StreamContextTimeout returns a stream client interceptor that applies the
// default timeout to streams whose context has no deadline. A deadline set by
// the caller is left untouched.
//
// The derived context must stay alive for as long as the stream is in use, so
// it is NOT cancelled when the interceptor returns. It is cancelled if stream
// creation fails, or when RecvMsg on the returned stream reports an error
// (io.EOF included). Otherwise it is released when the timeout expires.
func StreamContextTimeout() grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		ctx, cancel := defaultContextTimeout(ctx)
		if cancel == nil {
			// The caller supplied its own deadline; nothing to release here.
			return streamer(ctx, desc, cc, method, opts...)
		}

		cs, err := streamer(ctx, desc, cc, method, opts...)
		if err != nil {
			cancel()
			return nil, err
		}
		return &timeoutClientStream{ClientStream: cs, cancel: cancel}, nil
	}
}

// timeoutClientStream wraps a grpc.ClientStream and releases the timeout
// context created by StreamContextTimeout once the stream has ended.
type timeoutClientStream struct {
	grpc.ClientStream
	cancel context.CancelFunc
}

// RecvMsg forwards to the wrapped stream and cancels the timeout context as
// soon as the stream reports any error, which marks the end of the stream.
func (s *timeoutClientStream) RecvMsg(m interface{}) error {
	err := s.ClientStream.RecvMsg(m)
	if err != nil {
		// CancelFunc is idempotent, so repeated calls are harmless.
		s.cancel()
	}
	return err
}
错误码
pkg/errcode/common_error.go
package errcode
// Shared business error codes. Each value is created through NewError, so the
// codes listed here must be unique.
var (
Success = NewError(0, "成功") // success
Fail = NewError(10000000, "内部错误") // internal error
InvalidParams = NewError(10000001, "无效参数") // invalid parameters
Unauthorized = NewError(10000002, "认证错误") // authentication error
NotFound = NewError(10000003, "没有找到") // not found
Unknown = NewError(10000004, "未知") // unknown
DeadlineExceeded = NewError(10000005, "超出最后截止期限") // deadline exceeded
AccessDenied = NewError(10000006, "访问被拒绝") // access denied
LimitExceed = NewError(10000007, "访问限制") // access limit exceeded
MethodNotAllowed = NewError(10000008, "不支持该方法") // method not supported
)
pkg/errcode/errcode.go
package errcode
import "fmt"
// Error is a business error made of an integer code and a message.
type Error struct {
	code int
	msg  string
}

// _codes records every code handed to NewError, mapped to its message, so
// that duplicate registrations can be detected.
var _codes = map[int]string{}

// NewError registers code with msg and returns the corresponding *Error.
// It panics if code has already been registered.
func NewError(code int, msg string) *Error {
	if _, taken := _codes[code]; taken {
		panic(fmt.Sprintf("错误码 %d 已经存在,请更换一个", code))
	}

	e := &Error{code: code, msg: msg}
	_codes[e.code] = e.msg
	return e
}

// Error implements the error interface, rendering both code and message.
func (e *Error) Error() string {
	return fmt.Sprintf("错误码:%d, 错误信息:%s", e.code, e.msg)
}

// Code returns the numeric error code.
func (e *Error) Code() int { return e.code }

// Msg returns the error message.
func (e *Error) Msg() string { return e.msg }
pkg/errcode/rpc_error.go
package errcode
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "grpc-middleware-demo/proto"
)
// Status wraps *status.Status from google.golang.org/grpc/status. The pointer
// is embedded, so its methods are available directly on Status.
type Status struct {
*status.Status
}
// FromError converts err into a *Status. Whether err actually carried a gRPC
// status is not reported to the caller; the conversion result is wrapped and
// returned as is.
func FromError(err error) *Status {
	return &Status{Status: status.Convert(err)}
}
// TogRPCError converts a business *Error into a gRPC status error. The gRPC
// code comes from ToRPCCode, the message from err.Msg(), and a pb.Error with
// the original code and message is attached as a status detail.
//
// If the detail cannot be attached, the status without details is returned,
// so the failure is never collapsed into a nil error.
func TogRPCError(err *Error) error {
	base := status.New(ToRPCCode(err.Code()), err.Msg())
	detailed, detailErr := base.WithDetails(&pb.Error{Code: int32(err.Code()), Message: err.Msg()})
	if detailErr != nil {
		// Keep code and message; only the structured detail is lost.
		return base.Err()
	}
	return detailed.Err()
}
// ToRPCStatus builds a *Status for the given business code and message. The
// gRPC code comes from ToRPCCode, and a pb.Error with the original code and
// message is attached as a status detail.
//
// If the detail cannot be attached, the status without details is returned
// instead of a Status wrapping a nil pointer.
func ToRPCStatus(code int, msg string) *Status {
	base := status.New(ToRPCCode(code), msg)
	detailed, detailErr := base.WithDetails(&pb.Error{Code: int32(code), Message: msg})
	if detailErr != nil {
		// Keep code and message; only the structured detail is lost.
		return &Status{Status: base}
	}
	return &Status{Status: detailed}
}
// ToRPCCode maps a business error code to the corresponding gRPC status code.
// Codes without an explicit mapping yield codes.Unknown.
func ToRPCCode(code int) codes.Code {
	switch code {
	case Fail.Code():
		return codes.Internal
	case InvalidParams.Code():
		return codes.InvalidArgument
	case Unauthorized.Code():
		return codes.Unauthenticated
	case AccessDenied.Code():
		return codes.PermissionDenied
	case DeadlineExceeded.Code():
		return codes.DeadlineExceeded
	case NotFound.Code():
		return codes.NotFound
	case LimitExceed.Code():
		return codes.ResourceExhausted
	case MethodNotAllowed.Code():
		return codes.Unimplemented
	}
	return codes.Unknown
}