grpc-go服务端源码分析

1,640 阅读4分钟

代码讲解基于v1.37.0版本的greeter_server示例。

我们先看一段示例代码:

// port is the address handed to net.Listen by main.
const port = ":50051"

// server is used to implement helloworld.GreeterServer.
//
// It carries no state of its own; it embeds pb.UnimplementedGreeterServer and
// defines the SayHello method.
type server struct {
	pb.UnimplementedGreeterServer
}

// SayHello implements helloworld.GreeterServer.
//
// It logs the name carried by the request, simulates five seconds of work and
// then replies with "Hello <name>". The simulated work is abandoned as soon as
// ctx is cancelled or its deadline passes; in that case no reply is built and
// ctx.Err() is returned.
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
	log.Printf("Received: %v", in.GetName())

	// Wait on a timer instead of time.Sleep so that a cancelled or expired
	// request does not keep this handler goroutine parked for the full delay.
	delay := time.NewTimer(5 * time.Second)
	defer delay.Stop()
	select {
	case <-delay.C:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil
}

// main starts the example Greeter server: it listens on port, creates a
// grpc.Server, registers the server implementation on it and then serves.
// A failure to listen or to serve terminates the process via log.Fatalf.
func main() {
	lis, err := net.Listen("tcp", port)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	s := grpc.NewServer()
	pb.RegisterGreeterServer(s, &server{})
	// Serve runs the accept loop on lis; a non-nil error from it is fatal.
	if err := s.Serve(lis); err != nil {
		log.Fatalf("failed to serve: %v", err)
	}
}

首先使用net.Listen创建一个Listener的对象,Listener是通用的面向流协议的网络监听者,其定义如下:

// A Listener is a generic network listener for stream-oriented protocols.
//
// Multiple goroutines may invoke methods on a Listener simultaneously.
//
// In this walkthrough the value returned by net.Listen is passed to
// Server.Serve, which calls Accept in a loop.
type Listener interface {
	// Accept waits for and returns the next connection to the listener.
	Accept() (Conn, error)

	// Close closes the listener.
	// Any blocked Accept operations will be unblocked and return errors.
	Close() error

	// Addr returns the listener's network address.
	Addr() Addr
}

使用s := grpc.NewServer()初始化一个Server,其结构如下:

// Server holds the state of a gRPC server; NewServer creates and initializes it.
type Server struct {
	opts serverOptions

	mu       sync.Mutex // guards following
	lis      map[net.Listener]bool
	conns    map[transport.ServerTransport]bool
	serve    bool // whether the server has been started
	drain    bool
	cv       *sync.Cond              // signaled when connections close for GracefulStop
	services map[string]*serviceInfo // service name -> service info (implementation, methods, streams, metadata)
	events   trace.EventLog

	quit               *grpcsync.Event
	done               *grpcsync.Event
	channelzRemoveOnce sync.Once
	serveWG            sync.WaitGroup // counts active Serve goroutines for GracefulStop, so shutdown can wait for them

	channelzID int64 // channelz unique identification number
	czData     *channelzData

	serverWorkerChannels []chan *serverWorkerData
}

具体初始化逻辑在NewServer中,能够接收多个ServerOption参数,用于后续组装Server结构。

// NewServer builds a Server that has no services registered yet and has not
// started serving. Every supplied ServerOption is applied, in order, on top of
// defaultServerOptions.
func NewServer(opt ...ServerOption) *Server {
	// Start from the defaults and let each caller-supplied option adjust them.
	cfg := defaultServerOptions
	for i := range opt {
		opt[i].apply(&cfg)
	}

	srv := &Server{
		opts:     cfg,
		lis:      make(map[net.Listener]bool),
		conns:    make(map[transport.ServerTransport]bool),
		services: make(map[string]*serviceInfo),
		quit:     grpcsync.NewEvent(),
		done:     grpcsync.NewEvent(),
		czData:   new(channelzData),
	}

	// Set up the unary and stream interceptor chains (the "middleware").
	chainUnaryServerInterceptors(srv)
	chainStreamServerInterceptors(srv)
	srv.cv = sync.NewCond(&srv.mu)

	// With tracing enabled, label the event log with the file:line of
	// NewServer's caller.
	if EnableTracing {
		_, callerFile, callerLine, _ := runtime.Caller(1)
		location := fmt.Sprintf("%s:%d", callerFile, callerLine)
		srv.events = trace.NewEventLog("grpc.Server", location)
	}

	if srv.opts.numServerWorkers > 0 {
		srv.initServerWorkers()
	}

	if channelz.IsOn() {
		srv.channelzID = channelz.RegisterServer(&channelzServer{srv}, "")
	}
	return srv
}

回到示例中,使用pb.RegisterGreeterServer(s, &server{})将业务逻辑代码挂载到刚创建的grpc.Server对象上。

&server{}中的server结构体需要实现proto文件中定义的方法,也就是SayHello方法。

// Greeter is the greeting service; it declares a single RPC, SayHello.
service Greeter {
  // Sends a greeting
  rpc SayHello (HelloRequest) returns (HelloReply) {}
}

换个角度来说,需要实现代码生成的pb.go文件中的GreeterServer接口,如下面代码:

// GreeterServer is the server-side interface for the Greeter service. The
// value passed to RegisterGreeterServer must implement it.
type GreeterServer interface {
	// Sends a greeting
	SayHello(context.Context, *HelloRequest) (*HelloReply, error)
	// Unexported marker method. NOTE(review): presumably satisfied by embedding
	// UnimplementedGreeterServer, as the example server struct does - confirm
	// against the generated code.
	mustEmbedUnimplementedGreeterServer()
}

pb.RegisterGreeterServer(s, &server{})的调用流程如下,

// https://github.com/grpc/grpc-go/blob/v1.37.0/examples/helloworld/helloworld/helloworld.pb.go
//
// RegisterGreeterServer registers srv as the implementation of the Greeter
// service on s by passing the Greeter_ServiceDesc descriptor and srv to
// s.RegisterService.
func RegisterGreeterServer(s grpc.ServiceRegistrar, srv GreeterServer) {
	s.RegisterService(&Greeter_ServiceDesc, srv)
}

// Source: https://github.com/grpc/grpc-go/blob/v1.37.0/server.go#L578
//
// RegisterService registers a service and its implementation to the gRPC
// server. It is called from the IDL generated code. This must be called before
// invoking Serve. If ss is non-nil (for legacy code), its type is checked to
// ensure it implements sd.HandlerType.
func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
	// A nil implementation skips the type check entirely.
	if ss == nil {
		s.register(sd, ss)
		return
	}

	// sd.HandlerType is a pointer to the service interface; Elem yields the
	// interface type the implementation has to satisfy.
	wantIface := reflect.TypeOf(sd.HandlerType).Elem()
	implType := reflect.TypeOf(ss)
	if !implType.Implements(wantIface) {
		logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", implType, wantIface)
	}
	s.register(sd, ss)
}

// register records the service described by sd, together with its
// implementation ss, in s.services under sd.ServiceName. It holds s.mu for the
// whole call and reports a fatal error if the server is already serving or if
// the service name has been registered before.
func (s *Server) register(sd *ServiceDesc, ss interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()

	name := sd.ServiceName
	s.printf("RegisterService(%q)", name)
	if s.serve {
		logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", name)
	}
	if _, duplicate := s.services[name]; duplicate {
		logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", name)
	}

	// Index the unary method descriptors by method name. The map stores
	// pointers into sd.Methods rather than copies.
	unary := make(map[string]*MethodDesc)
	for i := range sd.Methods {
		unary[sd.Methods[i].MethodName] = &sd.Methods[i]
	}

	// Index the streaming descriptors by stream name in the same way.
	streaming := make(map[string]*StreamDesc)
	for i := range sd.Streams {
		streaming[sd.Streams[i].StreamName] = &sd.Streams[i]
	}

	// Associate the service name with its implementation and descriptors.
	s.services[name] = &serviceInfo{
		serviceImpl: ss,
		methods:     unary,
		streams:     streaming,
		mdata:       sd.Metadata,
	}
}

关于s.RegisterService(&Greeter_ServiceDesc, srv)中Greeter_ServiceDesc参数的说明:

// Greeter_ServiceDesc is the grpc.ServiceDesc for Greeter service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Greeter_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "helloworld.Greeter",  // package.service from the proto file: package helloworld, service Greeter
	HandlerType: (*GreeterServer)(nil), // nil pointer to the GreeterServer interface; RegisterService uses it to check that the implementation satisfies the interface
	Methods: []grpc.MethodDesc{ // unary method descriptors
		{
			MethodName: "SayHello",                // method name
			Handler:    _Greeter_SayHello_Handler, // function that executes the call
		},
	},
	Streams:  []grpc.StreamDesc{},                               // streaming descriptors; empty for this service
	Metadata: "examples/helloworld/helloworld/helloworld.proto", // path of the proto file
}

调用s.Serve(lis)启动服务端,在这之前都是准备阶段:

// Serve accepts incoming connections on lis in a loop and starts a goroutine
// running handleRawConn for each accepted connection.
//
// Temporary Accept errors are retried after a delay that starts at 5ms and
// doubles up to a cap of 1s; if s.quit fires during that wait, Serve returns
// nil. Any other Accept error ends the loop: nil is returned when s.quit has
// fired, otherwise the Accept error itself.
func (s *Server) Serve(lis net.Listener) error {
	// ...
	// (Some less important logic has been omitted from this excerpt.)
	var tempDelay time.Duration // how long to sleep on accept failure

	for {
		rawConn, err := lis.Accept()
		// On error: back off and retry if it is temporary, otherwise stop serving.
		if err != nil {
			if ne, ok := err.(interface {
				Temporary() bool
			}); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				s.mu.Lock()
				s.printf("Accept error: %v; retrying in %v", err, tempDelay)
				s.mu.Unlock()
				// Sleep for tempDelay, but wake up early if the server is quitting.
				timer := time.NewTimer(tempDelay)
				select {
				case <-timer.C:
				case <-s.quit.Done():
					timer.Stop()
					return nil
				}
				continue
			}
			s.mu.Lock()
			s.printf("done serving; Accept = %v", err)
			s.mu.Unlock()

			if s.quit.HasFired() {
				return nil
			}
			return err
		}
		// Accept succeeded: reset the backoff delay.
		tempDelay = 0
		// Start a new goroutine to deal with rawConn so we don't stall this Accept
		// loop goroutine.
		//
		// Make sure we account for the goroutine so GracefulStop doesn't nil out
		// s.conns before this conn can be added.
		s.serveWG.Add(1)
		go func() {
			s.handleRawConn(rawConn) // process the raw connection
			s.serveWG.Done()
		}()
	}
}

继续看handleRawConn

// handleRawConn wraps an accepted connection in an HTTP/2 server transport,
// adds that transport to the server and serves its streams on a new goroutine.
// It returns early if the transport could not be created or if addConn
// rejects it.
func (s *Server) handleRawConn(rawConn net.Conn) {
	// ...
	// (Less important code has been omitted from this excerpt; conn and
	// authInfo used below are defined there.)
	// ...
	// Finish handshaking (HTTP2)
	// Create the HTTP/2 transport layer for this connection.
	st := s.newHTTP2Transport(conn, authInfo)
	if st == nil {
		return
	}

	// The zero time.Time clears any deadline set on the connection.
	rawConn.SetDeadline(time.Time{})
	if !s.addConn(st) {
		return
	}
	go func() {
		s.serveStreams(st) // handle the streams carried by this transport
		s.removeConn(st)   // remove the connection once serving ends (presumably also signals waiters - confirm in removeConn)
	}()
}

继续看serveStreams

// serveStreams runs st.HandleStreams for one transport and then blocks until
// every stream handler started for it has finished; st is closed on return.
//
// HandleStreams receives two callbacks. The first is invoked for each new
// stream (one RPC at the transport layer) and dispatches it to handleStream.
// The second attaches a trace to the stream's context when EnableTracing is
// set and returns the context unchanged otherwise.
func (s *Server) serveStreams(st transport.ServerTransport) {
	defer st.Close()
	var wg sync.WaitGroup

	var roundRobinCounter uint32
	st.HandleStreams(func(stream *transport.Stream) {
		wg.Add(1)
		if s.opts.numServerWorkers > 0 {
			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
			select {
			// Pick one worker channel round-robin (atomic counter modulo the
			// number of workers) and attempt a non-blocking hand-off. The
			// receiver gets &wg inside data and is presumably responsible for
			// calling Done on it - confirm in the worker loop.
			case s.serverWorkerChannels[atomic.AddUint32(&roundRobinCounter, 1)%s.opts.numServerWorkers] <- data:
			default:
				// The chosen worker channel cannot take the stream right now:
				// fall back to the default path, i.e. a dedicated goroutine
				// for this stream.
				go func() {
					// Deferred, as in the no-worker path below, so wg.Wait
					// cannot hang if handleStream exits abnormally.
					defer wg.Done()
					s.handleStream(st, stream, s.traceInfo(st, stream))
				}()
			}
		} else {
			go func() {
				defer wg.Done()
				s.handleStream(st, stream, s.traceInfo(st, stream))
			}()
		}
	}, func(ctx context.Context, method string) context.Context {
		if !EnableTracing {
			return ctx
		}
		tr := trace.New("grpc.Recv."+methodFamily(method), method)
		return trace.NewContext(ctx, tr)
	})
	wg.Wait()
}

继续看handleStream

// handleStream routes one incoming stream to its handler.
//
// The stream's method string is expected to look like "/service/method": a
// leading '/' is stripped and the remainder is split at the last '/'. A
// registered unary method is dispatched to processUnaryRPC and a registered
// streaming method to processStreamingRPC. Otherwise the configured
// unknown-stream handler is used if there is one, and failing that an
// Unimplemented status is written back on the stream.
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
	sm := stream.Method()
	if sm != "" && sm[0] == '/' {
		sm = sm[1:]
	}
	pos := strings.LastIndex(sm, "/")
	if pos == -1 {
		// No '/' separator: the method name is malformed. Record the error on
		// the trace (if any), write an error status and stop.
		if trInfo != nil {
			trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
			trInfo.tr.SetError()
		}
		errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
		// NOTE(review): a malformed method name is reported as
		// ResourceExhausted here, while unknown services/methods below use
		// Unimplemented - confirm this code is intended.
		if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
			if trInfo != nil {
				trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
				trInfo.tr.SetError()
			}
			channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
		}
		if trInfo != nil {
			trInfo.tr.Finish()
		}
		return
	}
	service := sm[:pos]
	method := sm[pos+1:]

	// register stored the service descriptors in s.services; look them up here.
	srv, knownService := s.services[service]
	if knownService { // the service exists; now look for the method
		if md, ok := srv.methods[method]; ok {
			// Unary (single request/response) call.
			s.processUnaryRPC(t, stream, srv, md, trInfo)
			return
		}
		if sd, ok := srv.streams[method]; ok {
			// Streaming call (not analysed in this article).
			// srv carries the concrete service implementation.
			s.processStreamingRPC(t, stream, srv, sd, trInfo)
			return
		}
	}
	// Unknown service, or known server unknown method.
	// If an unknown-stream handler is configured, let it process the stream.
	if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
		s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
		return
	}
	var errDesc string
	if !knownService {
		errDesc = fmt.Sprintf("unknown service %v", service)
	} else {
		errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
	}
	if trInfo != nil {
		trInfo.tr.LazyPrintf("%s", errDesc)
		trInfo.tr.SetError()
	}
	// Write the Unimplemented status back on the stream.
	if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
		if trInfo != nil {
			trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
			trInfo.tr.SetError()
		}
		channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
	}
	if trInfo != nil {
		trInfo.tr.Finish()
	}
}

继续看processUnaryRPC

// processUnaryRPC handles one unary RPC: it receives (and decompresses) the
// request payload, invokes the generated method handler md.Handler with a
// decode callback, and sends the handler's reply with sendResponse.
//
// This is an abridged excerpt: sh, binlog, dc, decomp, cp, comp, opts and ctx
// are defined in code omitted from it.
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
    
    // ...
    // (Some less important code omitted.)
	
	// ... 
	// (Omitted logic, described by the article as encryption/decryption;
	// presumably the setup that defines dc, decomp, cp and comp - confirm
	// against the full source.)

	var payInfo *payloadInfo
	if sh != nil || binlog != nil {
		payInfo = &payloadInfo{}
	}
	// Receive the request message; per the article, recvMsg is used internally.
	d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
	if err != nil {
		if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
			channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status %v", e)
		}
		return err
	}
	
	// ... 
	
	// df is the decode callback handed to md.Handler below.
	df := func(v interface{}) error {
		// Unmarshal the request bytes into v. In _Greeter_SayHello_Handler, v is
		// created with in := new(HelloRequest) and receives the request data.
		if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
			return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
		}
		if sh != nil {
			// Report the received payload to sh (presumably the stats handler;
			// it is defined in the omitted code).
			sh.HandleRPC(stream.Context(), &stats.InPayload{
				RecvTime:   time.Now(),
				Payload:    v,
				WireLength: payInfo.wireLength + headerLen,
				Data:       d,
				Length:     len(d),
			})
		}
		if binlog != nil {
			binlog.Log(&binarylog.ClientMessage{
				Message: d,
			})
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
		}
		return nil
	}
	
	// Invoke the method handler.
	// For the helloworld example this is _Greeter_SayHello_Handler from the generated pb file.
	// info.serviceImpl is the concrete implementation; inside _Greeter_SayHello_Handler,
	// srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest)) calls the real method.
	reply, appErr := md.Handler(info.serviceImpl, ctx, df, s.opts.unaryInt)

	// Send the response data.
	if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
	    // ...
	}
	
	// ... 
}

返回去看看serveStreams方法中st.HandleStreams的实现,代码在./internal/transport/http2_server.go中,主要是http2通信相关逻辑,这部分后续补充。

// HandleStreams reads HTTP/2 frames from the connection in a loop and
// dispatches each one by type; handle and traceCtx are forwarded to
// operateHeaders for MetaHeaders frames. t.readerDone is closed when the
// method returns.
//
// A stream-level read error resets only the affected stream and the loop
// continues. Any other read error closes the transport and returns.
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
	defer close(t.readerDone)
	for {
		t.controlBuf.throttle()
		frame, err := t.framer.fr.ReadFrame()
		// Record the time of this read attempt, whether or not it succeeded.
		atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
		if err != nil {
			if se, ok := err.(http2.StreamError); ok {
				if logger.V(logLevel) {
					logger.Warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se)
				}
				t.mu.Lock()
				s := t.activeStreams[se.StreamID]
				t.mu.Unlock()
				if s != nil {
					t.closeStream(s, true, se.Code, false)
				} else {
					// The stream is not active: queue a cleanupStream item
					// carrying the reset code for it.
					t.controlBuf.put(&cleanupStream{
						streamID: se.StreamID,
						rst:      true,
						rstCode:  se.Code,
						onWrite:  func() {},
					})
				}
				continue
			}
			if err == io.EOF || err == io.ErrUnexpectedEOF {
				t.Close()
				return
			}
			if logger.V(logLevel) {
				logger.Warningf("transport: http2Server.HandleStreams failed to read frame: %v", err)
			}
			t.Close()
			return
		}
		switch frame := frame.(type) {
		case *http2.MetaHeadersFrame:
			// The handle callback is used here.
			// Note: break only leaves the switch, so after t.Close() the
			// enclosing for loop runs another iteration.
			if t.operateHeaders(frame, handle, traceCtx) {
				t.Close()
				break
			}
		case *http2.DataFrame:
			// Process the message content.
			t.handleData(frame)
		case *http2.RSTStreamFrame:
			t.handleRSTStream(frame)
		case *http2.SettingsFrame:
			t.handleSettings(frame)
		case *http2.PingFrame:
			t.handlePing(frame)
		case *http2.WindowUpdateFrame:
			t.handleWindowUpdate(frame)
		case *http2.GoAwayFrame:
			// TODO: Handle GoAway from the client appropriately.
		default:
			if logger.V(logLevel) {
				logger.Errorf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
			}
		}
	}
}

processUnaryRPC方法中使用sendResponse来发送响应,而sendResponse内部通过err = t.Write(stream, hdr, payload, opts)调用底层http2_server中的Write方法。

看看sendResponse方法的实现,

// sendResponse encodes msg with the codec chosen for the stream's content
// subtype, compresses it, and writes the resulting header and payload to the
// stream through t.Write.
//
// It returns the encode, compress or write error if one occurs, and a
// ResourceExhausted status error when the payload is larger than
// s.opts.maxSendMessageSize. After a successful write, the stats handler (if
// configured) is notified of the outgoing payload.
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
	data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
	if err != nil {
		channelz.Error(logger, s.channelzID, "grpc: server failed to encode response: ", err)
		return err
	}
	compData, err := compress(data, cp, comp)
	if err != nil {
		channelz.Error(logger, s.channelzID, "grpc: server failed to compress response: ", err)
		return err
	}
	// Build the message header and select the payload to send from the raw
	// and compressed forms.
	hdr, payload := msgHeader(data, compData)
	// TODO(dfawley): should we be checking len(data) instead?
	if len(payload) > s.opts.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
	}
	// Hand the header and payload to the transport implementation.
	err = t.Write(stream, hdr, payload, opts)
	if err == nil && s.opts.statsHandler != nil {
		s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now()))
	}
	return err
}

gRPC的server端主逻辑流程还是比较清楚、简单的,底层通信协议是http2协议,具体的实现代码在google.golang.org/grpc/internal/transport包中。transport.go文件中的ServerTransport interface{}定义了服务端http2结构体的通用方法,具体的实现原理会在写完client端的源码实现之后统一讲,敬请期待。