gRPC-go源码解析-server篇

1,157 阅读7分钟

有关gRPC是什么和其特性这里不详细展开,相关的文章特别多,自行搜索即可,这里直接开始对官方示例hello world项目源码的阅读。

hello-world例子中的server/main.go中我们看到生成一个server的代码段如下:

lis, err := net.Listen("tcp", port)
if err != nil {
	log.Fatalf("failed to listen: %v", err)
}
s := grpc.NewServer()
pb.RegisterGreeterServer(s, &server{})
if err := s.Serve(lis); err != nil {
	log.Fatalf("failed to serve: %v", err)
}

从这段代码逻辑可以看出创建一个server大致分为如下几步:

  • 创建一个新的server(grpc.NewServer())
  • server进行注册
  • 调用方法监听端口

创建server

在这段逻辑里我们通过进入NewServer方法,可以看到如下代码:

func NewServer(opt ...ServerOption) *Server {
	opts := defaultServerOptions // 1.
	for _, o := range opt {
		o.apply(&opts)
	}
	s := &Server{ // 2.
		lis:    make(map[net.Listener]bool),
		opts:   opts,
		conns:  make(map[transport.ServerTransport]bool),
		m:      make(map[string]*service),
		quit:   grpcsync.NewEvent(),
		done:   grpcsync.NewEvent(),
		czData: new(channelzData),
	}
	chainUnaryServerInterceptors(s) // 3.
	chainStreamServerInterceptors(s)
	s.cv = sync.NewCond(&s.mu)
	if EnableTracing {
		_, file, line, _ := runtime.Caller(1)
		s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
	}

	if channelz.IsOn() {
		s.channelzID = channelz.RegisterServer(&channelzServer{s}, "")
	}
	return s
}

接下来我们逐行进行分析,首先该方法的入参是一些服务器的可选参数,ServerOption本身是一个接口,里面有一个apply方法

type ServerOption interface {
    apply(*serverOptions)
}

根据官方给出的注释,我们可以看到这个接口里的方法主要是对server设置一些可选参数,比如codec,或者是参数的生命周期等。而serverOptions这个结构体也恰恰定义的是这些服务器的参数。

回到注释1的位置,那么首先我们的程序就是设置了必要的服务器参数,具体的内容如下:

var defaultServerOptions = serverOptions{
	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
	maxSendMessageSize:    defaultServerMaxSendMessageSize,
	connectionTimeout:     120 * time.Second,
	writeBufferSize:       defaultWriteBufSize,
	readBufferSize:        defaultReadBufSize,
}

比如说默认的最大可接收和发送的消息大小,连接的超时时间和Buffer的大小。

然后进入的for循环就是将服务器参数都设置到我们的optionServers这个结构体中。

再向下走进入到注释2的位置,就是对我们的Server结构体做了设置,Server的结构如下:

type Server struct {
    opts serverOptions

    mu     sync.Mutex // 互斥锁
    lis    map[net.Listener]bool // listener map
    conns  map[transport.ServerTransport]bool //connextions map
    serve  bool // 是否在处理请求的状态位
    drain  bool
    cv     *sync.Cond          // signaled when connections close for GracefulStop
    m      map[string]*service // service name -> service info
    events trace.EventLog

    quit               *grpcsync.Event
    done               *grpcsync.Event
    channelzRemoveOnce sync.Once
    serveWG            sync.WaitGroup // counts active Serve goroutines for GracefulStop

    channelzID int64 // channelz unique identification number
    czData     *channelzData
}

可以看到,这里面比较重要的有3个map,分别存储的是listener的信息,connection的信息以及提供的service的信息。而其他的字段主要是提供了某些服务器的状态信息或者并发控制的功能。

那么首先对于存储listener的map而言,listener本质是一个接口,里面提供了Accept(),Close(),Addr()三个方法,分别提供是服务器准备进行连接,和关闭listener,以及返回listener网络地址的功能。

对存储service的map而言,service的结构如下:

type service struct {
    server interface{} // the server for service methods
    md     map[string]*MethodDesc
    sd     map[string]*StreamDesc
    mdata  interface{}
}

server接口里存放的是该服务器所提供的service方法,而下面两个map则是存储了method和stream流的服务信息。

type MethodDesc struct {
    MethodName string
    Handler    methodHandler
}

type StreamDesc struct {
    StreamName string
    Handler    StreamHandler

    // At least one of these is true.
    ServerStreams bool
    ClientStreams bool
}

其中每一个struct里面都有一个handler来对调用的方法进行处理。

回到主流程的注释3的位置,可以看到有两个拦截器,第一个拦截器主要是将我们定义的server端的拦截器最终都用一个拦截器链进行规整:(具体内容见注释)

func chainUnaryServerInterceptors(s *Server) {
	// Prepend opts.unaryInt to the chaining interceptors if it exists, since unaryInt will
	// be executed before any other chained interceptors.
    // 这几步主要是检查一下拦截器的个数,如果方法unary拦截器数组不是空的话,就要把这些拦截器继续添加到我们的拦截器链上
	interceptors := s.opts.chainUnaryInts // 
	if s.opts.unaryInt != nil {
		interceptors = append([]UnaryServerInterceptor{s.opts.unaryInt}, s.opts.chainUnaryInts...)
	}

	var chainedInt UnaryServerInterceptor
	if len(interceptors) == 0 {
		chainedInt = nil
	} else if len(interceptors) == 1 {
		chainedInt = interceptors[0]
	} else {
		chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
            // 如果拦截器数量大于1,那么会递归的生成一个拦截器链
			return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
		}
	}

	s.opts.unaryInt = chainedInt
}

其中,getChainUnaryHandler方法的逻辑需要看一下:

func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
	if curr == len(interceptors)-1 {
		return finalHandler
	}

	return func(ctx context.Context, req interface{}) (interface{}, error) {
		return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
	}
}

在本段逻辑里,首先判断当前的curr指针是不是整个拦截器链的末尾,如果是的话就会返回最尾端的handler,否则就不断的递归,从而生成一条拦截器链。

chainStreamServerInterceptorschainUnaryServerInterceptors方法类似,这里不再赘述。

server注册

继续回到主线,在完成了相应的server设置之后,就要对server进行注册,跟随

pb.RegisterGreeterServer(s, &server{})方法一路进入到RegisterService方法,即:

func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
	ht := reflect.TypeOf(sd.HandlerType).Elem()
	st := reflect.TypeOf(ss)
	if !st.Implements(ht) {
		grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
	}
	s.register(sd, ss)
}

首先就是通过反射来获取server当中handler链中每个handler的类型,再获自己定义的service的类型,紧接着就是判断一下是否自定义的service的类型实现了我们要求的server里面handler的类型,如果是的话就进入register的逻辑。

func (s *Server) register(sd *ServiceDesc, ss interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.printf("RegisterService(%q)", sd.ServiceName)
	if s.serve {
		grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, ok := s.m[sd.ServiceName]; ok {
		grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}
	srv := &service{
		server: ss,
		md:     make(map[string]*MethodDesc),
		sd:     make(map[string]*StreamDesc),
		mdata:  sd.Metadata,
	}
	for i := range sd.Methods {
		d := &sd.Methods[i]
		srv.md[d.MethodName] = d
	}
	for i := range sd.Streams {
		d := &sd.Streams[i]
		srv.sd[d.StreamName] = d
	}
	s.m[sd.ServiceName] = srv
}

首先在这段逻辑里最开始判断了下是不是先注册了方法后启动的server,和是不是重复注册了service,如果都没有的话,则按照方法名为 key,将方法注入到 server 的 service map 中。看到这里我们其实可以预测一下,server 不同 rpc 请求的处理,也是根据 service 中不同的 serviceName 去 service map 中取出不同的 handler 进行处理。

启动Serve

对于通常的C/S架构的通信,普遍的实现都是server端不断地嗅探本端口是不是有连接请求,如果有client端进行连接,那么就握手建立连接,然后client通过调用相应的方法和参数对server的service进行调用,这个请求就对打到server端的handler处来进行处理。所以,对 server 端来说,主要是了解其如何实现监听,如何为请求分配不同的 handler 和 回写响应数据。来看一下具体的Serve方法

func (s *Server) Serve(lis net.Listener) error {
	s.mu.Lock()
	s.printf("serving")
	s.serve = true
	if s.lis == nil {
		// Serve called after Stop or GracefulStop.
		s.mu.Unlock()
		lis.Close()
		return ErrServerStopped
	}

	s.serveWG.Add(1)
	defer func() {
		s.serveWG.Done()
		if s.quit.HasFired() {
			// Stop or GracefulStop called; block until done and return nil.
			<-s.done.Done()
		}
	}()

	ls := &listenSocket{Listener: lis}
	s.lis[ls] = true

	if channelz.IsOn() {
		ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String())
	}
	s.mu.Unlock()

	defer func() {
		s.mu.Lock()
		if s.lis != nil && s.lis[ls] {
			ls.Close()
			delete(s.lis, ls)
		}
		s.mu.Unlock()
	}()

	var tempDelay time.Duration // how long to sleep on accept failure

	for {
		rawConn, err := lis.Accept() // 4
		if err != nil {
			if ne, ok := err.(interface {
				Temporary() bool
			}); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				s.mu.Lock()
				s.printf("Accept error: %v; retrying in %v", err, tempDelay)
				s.mu.Unlock()
				timer := time.NewTimer(tempDelay)
				select {
				case <-timer.C:
				case <-s.quit.Done():
					timer.Stop()
					return nil
				}
				continue
			}
			s.mu.Lock()
			s.printf("done serving; Accept = %v", err)
			s.mu.Unlock()

			if s.quit.HasFired() {
				return nil
			}
			return err
		}
		tempDelay = 0
		// Start a new goroutine to deal with rawConn so we don't stall this Accept
		// loop goroutine.
		//
		// Make sure we account for the goroutine so GracefulStop doesn't nil out
		// s.conns before this conn can be added.
		s.serveWG.Add(1) // 5.
		go func() {
			s.handleRawConn(rawConn)
			s.serveWG.Done()
		}()
	}
}

从注释4开始,我们看到程序进入循环然后监听对应的端口。紧接在看到注释5我们发现程序起了一个goroutine去调用handleRawConn方法,进一步跟踪进去:

func (s *Server) handleRawConn(rawConn net.Conn) {
    // ... 
    conn, authInfo, err := s.useTransportAuthenticator(rawConn)
    // ...
    // Finish handshaking (HTTP2)
    st := s.newHTTP2Transport(conn, authInfo)
    if st == nil {
        return
    }
    // ...
    go func() {
        s.serveStreams(st)
        s.removeConn(st)
    }()
}

我将不重要的地方进行了省略,可以看到在本方法内,确实是通过建立HTTP2的握手来实现了连接的建立,然后程序又开了一个goroutine来调用了serveStreams方法:

func (s *Server) serveStreams(st transport.ServerTransport) {
    defer st.Close()
    var wg sync.WaitGroup
    st.HandleStreams(func(stream *transport.Stream) {
        wg.Add(1)
        go func() {
            defer wg.Done()
            s.handleStream(st, stream, s.traceInfo(st, stream))
        }()
    }, func(ctx context.Context, method string) context.Context {
        if !EnableTracing {
            return ctx
        }
        tr := trace.New("grpc.Recv."+methodFamily(method), method)
        return trace.NewContext(ctx, tr)
    })
    wg.Wait()
}

在这里依然是调用了handleStreams方法:

func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
    sm := stream.Method()
    ...
    service := sm[:pos]
    method := sm[pos+1:]
    srv, knownService := s.m[service]
    if knownService {
        if md, ok := srv.md[method]; ok {
            s.processUnaryRPC(t, stream, srv, md, trInfo)
            return
        }
        if sd, ok := srv.sd[method]; ok {
            s.processStreamingRPC(t, stream, srv, sd, trInfo)
            return
        }
    }
    ...
}

在这里,果然,程序根据 serviceName 去 server 中的 service map,也就是 m 这个字段里去取出 handler 进行处理。我们 hello world 这个 demo 的请求不涉及到 stream ,所以直接取出 handler ,然后传给 processUnaryRPC 这个方法进行处理。

所以我们进一步跟进processUnaryRPC方法:

func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
    // ...
    sh := s.opts.statsHandler
    if sh != nil {
        beginTime := time.Now()
        begin := &stats.Begin{
            BeginTime: beginTime,
        }
        sh.HandleRPC(stream.Context(), begin)
        defer func() {
            end := &stats.End{
                BeginTime: beginTime,
                EndTime:   time.Now(),
            }
            if err != nil && err != io.EOF {
                end.Error = toRPCErr(err)
            }
            sh.HandleRPC(stream.Context(), end)
        }()
    }
    // ...
    if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
        if err == io.EOF {
            // The entire stream is done (for unary RPC only).
            return err
        }
        if s, ok := status.FromError(err); ok {
            if e := t.WriteStatus(stream, s); e != nil {
                grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e)
            }
        } else {
            switch st := err.(type) {
            case transport.ConnectionError:
                // Nothing to do here.
            default:
                panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
            }
        }
        if binlog != nil {
            h, _ := stream.Header()
            binlog.Log(&binarylog.ServerHeader{
                Header: h,
            })
            binlog.Log(&binarylog.ServerTrailer{
                Trailer: stream.Trailer(),
                Err:     appErr,
            })
        }
        return err
    }
    // ...
}

我们发现了对handler方法的调用和response的回写。

那么gRPC server端部分源码到这里就结束了,有关其他的stream方式放到以后的文章进行阅读。