深挖 Go 标准库 net/http 源码,来看看优雅退出到底是如何实现的?

351 阅读11分钟

Shutdown 源码

我在《Go 标准库 net/http 如何实现优雅退出?》 一文中讲解了 net/http 实现的 HTTP Server 如何优雅退出,本文就来讲解下优雅退出源码是如何实现的。

Shutdown 方法源码如下:

github.com/golang/go/b…

// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the [Server]'s underlying Listener(s).
//
// When Shutdown is called, [Serve], [ListenAndServe], and
// [ListenAndServeTLS] immediately return [ErrServerClosed]. Make sure the
// program doesn't exit and waits instead for Shutdown to return.
//
// Shutdown does not attempt to close nor wait for hijacked
// connections such as WebSockets. The caller of Shutdown should
// separately notify such long-lived connections of shutdown and wait
// for them to close, if desired. See [Server.RegisterOnShutdown] for a way to
// register shutdown notification functions.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
	srv.inShutdown.Store(true)

	srv.mu.Lock()
	lnerr := srv.closeListenersLocked()
	for _, f := range srv.onShutdown {
		go f()
	}
	srv.mu.Unlock()
	srv.listenerGroup.Wait()

	pollIntervalBase := time.Millisecond
	nextPollInterval := func() time.Duration {
		// Add 10% jitter.
		interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
		// Double and clamp for next time.
		pollIntervalBase *= 2
		if pollIntervalBase > shutdownPollIntervalMax {
			pollIntervalBase = shutdownPollIntervalMax
		}
		return interval
	}

	timer := time.NewTimer(nextPollInterval())
	defer timer.Stop()
	for {
		if srv.closeIdleConns() {
			return lnerr
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			timer.Reset(nextPollInterval())
		}
	}
}

首先 Shutdown 方法注释写的非常清晰:

Shutdown 会优雅地关闭服务器,而不会中断任何活动的连接。它的工作原理是先关闭所有已打开的监听器(listeners),然后关闭所有空闲的连接,并无限期地等待所有连接变为空闲状态后再关闭服务器。如果在关闭完成之前,传入的上下文(context)过期,Shutdown 会返回上下文的错误,否则它将返回关闭服务器底层监听器时所产生的任何错误。

当调用 Shutdown 时,[Serve]、[ListenAndServe] 和 [ListenAndServeTLS] 会立即返回 [ErrServerClosed] 错误。请确保程序不会直接退出,而是等待 Shutdown 返回后再退出。

Shutdown 不会尝试关闭或等待被劫持的连接(例如 WebSocket)。Shutdown 的调用者应单独通知这些长时间存在的连接关于关闭的信息,并根据需要等待它们关闭。可以参考 [Server.RegisterOnShutdown] 来注册关闭通知函数。

一旦在服务器上调用了 Shutdown,它将无法再次使用;之后对 Serve 等方法的调用将返回 ErrServerClosed 错误。

通过这段注释,我们就能对 Shutdown 方法执行流程有个大概理解。

接着我们来从上到下依次分析下 Shutdown 源码。

第一行代码如下:

srv.inShutdown.Store(true)

Shutdown 首先将 inShutdown 标记为 trueinShutdownatomic.Bool 类型,它用来标记服务器是否正在关闭。

这里使用了 atomic 来保证操作的原子性,以免其他方法读取到错误的 inShutdown 标志位,发生错误。避免 HTTP Server 进程已经开始处理结束逻辑,还会有新的请求进入到 srv.Serve 方法。

接着 Shutdown 会关闭监听的端口:

srv.mu.Lock()
lnerr := srv.closeListenersLocked()
for _, f := range srv.onShutdown {
    go f()
}
srv.mu.Unlock()

代码中的 srv.closeListenersLocked() 就是在关闭所有的监听器(listeners)。

方法定义如下:

func (s *Server) closeListenersLocked() error {
	var err error
	for ln := range s.listeners {
		if cerr := (*ln).Close(); cerr != nil && err == nil {
			err = cerr
		}
	}
	return err
}

这一操作,就对应了在前文中讲解的 HTTP Server 优雅退出流程中的第 1 步,关闭所有开启的 net.Listener 对象。

接下来循环遍历 srv.onShutdown 中的函数,并依次启动新的 goroutine 对其进行调用。

onShutdown[]func() 类型,其切片内容正是在我们调用 srv.RegisterOnShutdown 的时候注册进来的。

srv.RegisterOnShutdown 定义如下:

// RegisterOnShutdown registers a function to call on [Server.Shutdown].
// This can be used to gracefully shutdown connections that have
// undergone ALPN protocol upgrade or that have been hijacked.
// This function should start protocol-specific graceful shutdown,
// but should not wait for shutdown to complete.
func (srv *Server) RegisterOnShutdown(f func()) {
	srv.mu.Lock()
	srv.onShutdown = append(srv.onShutdown, f)
	srv.mu.Unlock()
}

这是我们在前文中的使用示例:

// 可以注册一些 hook 函数,比如从注册中心下线逻辑
srv.RegisterOnShutdown(func() {
    log.Println("Register Shutdown 1")
})
srv.RegisterOnShutdown(func() {
    log.Println("Register Shutdown 2")
})

接着代码执行到这一步:

srv.listenerGroup.Wait()

根据这个操作的属性名和方法名可以猜到,listenerGroup 明显是 sync.WaitGroup 类型。

既然有 Wait(),那就应该会有 Add(1) 操作。在源码中搜索 listenerGroup.Add(1) 关键字,可以搜到如下方法:

// trackListener adds or removes a net.Listener to the set of tracked
// listeners.
//
// We store a pointer to interface in the map set, in case the
// net.Listener is not comparable. This is safe because we only call
// trackListener via Serve and can track+defer untrack the same
// pointer to local variable there. We never need to compare a
// Listener from another caller.
//
// It reports whether the server is still up (not Shutdown or Closed).
func (s *Server) trackListener(ln *net.Listener, add bool) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.listeners == nil {
		s.listeners = make(map[*net.Listener]struct{})
	}
	if add {
		if s.shuttingDown() {
			return false
		}
		s.listeners[ln] = struct{}{}
		s.listenerGroup.Add(1)
	} else {
		delete(s.listeners, ln)
		s.listenerGroup.Done()
	}
	return true
}

trackListener 用于添加或移除一个 net.Listener 到已跟踪的监听器集合中。

这个方法会被 Serve 方法调用,而实际上我们执行 srv.ListenAndServe() 的方法内部,也是在调用 Serve 方法。

Serve 方法定义如下:

// Serve accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
//
// HTTP/2 support is only enabled if the Listener returns [*tls.Conn]
// connections and they were configured with "h2" in the TLS
// Config.NextProtos.
//
// Serve always returns a non-nil error and closes l.
// After [Server.Shutdown] or [Server.Close], the returned error is [ErrServerClosed].
func (srv *Server) Serve(l net.Listener) error {
	if fn := testHookServerServe; fn != nil {
		fn(srv, l) // call hook with unwrapped listener
	}

	origListener := l
	l = &onceCloseListener{Listener: l}
	defer l.Close()

	if err := srv.setupHTTP2_Serve(); err != nil {
		return err
	}

	// 将 `net.Listener` 添加到已跟踪的监听器集合中
	// 内部会通过调用 s.shuttingDown() 判断是否正在进行退出操作,如果是,则返回 ErrServerClosed
	if !srv.trackListener(&l, true) {
		return ErrServerClosed
	}
	// Serve 函数退出时,将 `net.Listener` 从已跟踪的监听器集合中移除
	defer srv.trackListener(&l, false)

	baseCtx := context.Background()
	if srv.BaseContext != nil {
		baseCtx = srv.BaseContext(origListener)
		if baseCtx == nil {
			panic("BaseContext returned a nil context")
		}
	}

	var tempDelay time.Duration // how long to sleep on accept failure

	ctx := context.WithValue(baseCtx, ServerContextKey, srv)
	for {
		rw, err := l.Accept()
		if err != nil {
			// 每次新的请求进来,先判断当前服务是否已经被标记为正在关闭,如果是,则直接返回 ErrServerClosed
			if srv.shuttingDown() {
				return ErrServerClosed
			}
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				srv.logf("http: Accept error: %v; retrying in %v", err, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			return err
		}
		connCtx := ctx
		if cc := srv.ConnContext; cc != nil {
			connCtx = cc(connCtx, rw)
			if connCtx == nil {
				panic("ConnContext returned nil")
			}
		}
		tempDelay = 0
		c := srv.newConn(rw)
		c.setState(c.rwc, StateNew, runHooks) // before Serve can return
		go c.serve(connCtx)
	}
}

Serve 方法内部,我们先重点关注如下代码段:

if !srv.trackListener(&l, true) {
    return ErrServerClosed
}
defer srv.trackListener(&l, false)

这说明 Serve 在启动的时候会将一个新的监听器(net.Listener)加入到 listeners 集合中。

Serve 函数退出时,会对其进行移除。

并且,srv.trackListener 内部又调用了 s.shuttingDown() 判断当前服务是否正在进行退出操作,如果是,则返回 ErrServerClosed

// 标记为关闭状态,就不会有请求进来,直接返回错误
if srv.shuttingDown() {
    return ErrServerClosed
}

shuttingDown 定义如下:

func (s *Server) shuttingDown() bool {
	return s.inShutdown.Load()
}

同理,在 for 循环中,每次 rw, err := l.Accept() 收到新的请求,都会先判断当前服务是否已经被标记为正在关闭,如果是,则直接返回 ErrServerClosed

这里其实就是在跟 Shutdown 方法中的 srv.inShutdown.Store(true) 进行配合操作。

Shutdown 收到优雅退出请求,就将 inShutdown 标记为 true。此时 Serve 方法内部为了不再接收新的请求进来,每次都会调用 s.shuttingDown() 进行判断。保证不会再有新的请求进来,导致 Shutdown 无法退出。

这跟前文讲解完全吻合,在 Shutdown 方法还没执行完成的时候,Serve 方法其实已经退出了。也是我们为什么将 srv.ListenAndServe() 代码放到子 goroutine 中的原因。

Shutdown 接着往下执行:

pollIntervalBase := time.Millisecond
nextPollInterval := func() time.Duration {
    // Add 10% jitter.
    interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
    // Double and clamp for next time.
    pollIntervalBase *= 2
    if pollIntervalBase > shutdownPollIntervalMax {
        pollIntervalBase = shutdownPollIntervalMax
    }
    return interval
}

timer := time.NewTimer(nextPollInterval())
defer timer.Stop()
for {
    if srv.closeIdleConns() {
        return lnerr
    }
    select {
    case <-ctx.Done():
        return ctx.Err()
    case <-timer.C:
        timer.Reset(nextPollInterval())
    }
}

这一段逻辑比较多,并且 nextPollInterval 函数看起来比较迷惑,不过没关系,我们一点点来分析。

我们把这段代码中的 nextPollInterval 函数单独拿出来跑一下,就能大概知道它的意图了:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

func main() {
	const shutdownPollIntervalMax = 500 * time.Millisecond
	pollIntervalBase := time.Millisecond
	nextPollInterval := func() time.Duration {
		// Add 10% jitter.
		interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
		// Double and clamp for next time.
		pollIntervalBase *= 2
		if pollIntervalBase > shutdownPollIntervalMax {
			pollIntervalBase = shutdownPollIntervalMax
		}
		return interval
	}

	for i := 0; i < 20; i++ {
		fmt.Println(nextPollInterval())
	}
}

执行这段程序,输入结果如下:

$ go run main.go
1.078014ms
2.007835ms
4.151327ms
8.474296ms
17.487625ms
34.403371ms
64.613106ms
136.696655ms
273.873977ms
516.290814ms
502.815326ms
516.160214ms
523.34143ms
537.808701ms
518.913897ms
526.711692ms
518.421559ms
527.229427ms
526.904891ms
502.738764ms

我们可以把随机数再去掉。

把这行代码:

interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))

改成这样:

interval := pollIntervalBase + time.Duration(pollIntervalBase/10)

重新执行这段程序,输入结果如下:

$ go run main.go
1.1ms
2.2ms
4.4ms
8.8ms
17.6ms
35.2ms
70.4ms
140.8ms
281.6ms
550ms
550ms
550ms
550ms
550ms
550ms
550ms
550ms
550ms
550ms
550ms

根据输出结果,我们可以清晰的看出,nextPollInterval 函数执行返回值,开始时按照 2 倍方式增长,最终固定在 550ms

pollIntervalBase 最终值等于 shutdownPollIntervalMax

根据公式计算 interval := pollIntervalBase + time.Duration(pollIntervalBase/10),即 interval = 500 + 500 / 10 = 550ms,计算结果与输出结果相吻合。

说白了,这段代码写这么复杂,其核心目的就是为 Shutdown 方法中等待空闲连接关闭的轮询操作设计一个动态的、带有抖动(jitter)的时间间隔。这种设计确保服务器在执行优雅退出时,能够有效地处理剩余的空闲连接,同时避免不必要的资源浪费。

现在来看 for 循环这段代码,就非常好理解了:

timer := time.NewTimer(nextPollInterval())
defer timer.Stop()
for {
	if srv.closeIdleConns() {
		return lnerr
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		timer.Reset(nextPollInterval())
	}
}

这就是 Go 常用的定时器惯用法。

根据 nextPollInterval() 返回值大小,每次定时循环调用 srv.closeIdleConns() 方法。

并且这里有一个 case 执行了 case <-ctx.Done(),这正是我们调用 srv.Shutdown(ctx) 时,用来控制超时时间传递进来的 Context

另外,值得一提的是,在 Go 1.15 及以前的版本的 Shutdown 代码中这段定时器代码并不是这样实现的。

旧版本代码实现如下:

github.com/golang/go/b…

var shutdownPollInterval = 500 * time.Millisecond
...

ticker := time.NewTicker(shutdownPollInterval)
defer ticker.Stop()
for {
    if srv.closeIdleConns() && srv.numListeners() == 0 {
        return lnerr
    }
    select {
    case <-ctx.Done():
        return ctx.Err()
    case <-ticker.C:
    }
}

旧版本代码实现更加简单,并没有使用 time.NewTimer,而是使用了 time.NewTicker。这样实现的好处是代码简单,逻辑清晰,没花哨的功能。

其实我们在工作中写代码也是一样的道理,先让代码 run 起来,后期再考虑优化的问题。Shutdown 方法在 Go 1.8 版本被加入,直到 Go 1.16 版本这段代码才发生改变。

旧版本代码使用 time.NewTicker 是因为每次定时循环的周期都是固定值,不需要改变。

新版本代码使用 time.NewTimer 是为了在每次循环周期中调用 timer.Reset 重置间隔时间。

这也是一个值得学习的小技巧。我们在工作中经常会遇到类似的需求:每隔一段时间,执行一次操作。最简单的方式就是使用 time.Sleep 来做间隔时长,然后就是 time.NewTickertime.NewTimer 这两种方式。这 3 种方式其实都能实现每隔一段时间执行一次操作,但它们适用场景又有所不同。

time.Sleep 是用阻塞当前 goroutine 的方式来实现的,它需要调度器先唤醒当前 goroutine,然后才能执行后续代码逻辑。

time.Ticker 创建了一个底层数据结构定时器 runtimeTimer,并且监听 runtimeTimer 计时结束后产生的信号。因为 Go 为其进行了优化,所以它的 CPU 消耗比 time.Sleep 小很多。

time.Timer 底层也是定时器 runtimeTimer,只不过我们可以方便的使用 timer.Reset 重置间隔时间。

所以这 3 者都有各自适用的场景。

现在我们需要继续跟踪的代码就剩下 srv.closeIdleConns() 了,根据方法命名我们也能大概猜测到它的用途就是为了关闭空闲连接。

closeIdleConns 方法定义如下:

// closeIdleConns closes all idle connections and reports whether the
// server is quiescent.
func (s *Server) closeIdleConns() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	quiescent := true
	for c := range s.activeConn {
		st, unixSec := c.getState()
		// Issue 22682: treat StateNew connections as if
		// they're idle if we haven't read the first request's
		// header in over 5 seconds.
		//  这里预留 5s,防止在第一次读取连接头部信息时超过 5s
		if st == StateNew && unixSec < time.Now().Unix()-5 {
			st = StateIdle
		}
		if st != StateIdle || unixSec == 0 {
			// Assume unixSec == 0 means it's a very new
			// connection, without state set yet.
			// // unixSec == 0 代表这个连接是非常新的连接,则标志位被置为 false
			quiescent = false
			continue
		}
		c.rwc.Close()
		delete(s.activeConn, c)
	}
	return quiescent
}

这个方法比较核心,所以整个操作做了加锁处理。

使用 for 循环遍历所有连接,activeConn 是一个集合,类型为 map[*conn]struct{},里面记录了所有存活的连接。

c.getState() 能够获取连接的当前状态,对应的还有一个 setState 方法能够设置状态,setState 方法会在 Serve 方法中被调用。这其实就形成闭环了,每次有新的请求进来,都会设置连接状态(Serve 会根据当前处理请求的进度,将连接状态设置成 StateNewStateActiveStateIdleStateClosed 等),而在 Shutdown 方法中获取连接状态。

接着,代码中会判断连接中的请求是否已经完成操作(即:是否处于空闲状态 StateIdle),如果是,就直接将连接关闭,并从连接集合中移除,否则,跳过此次循环,等待下次循环周期。

这里调用 c.rwc.Close() 关闭连接,调用 delete(s.activeConn, c) 将当前连接从集合中移除,直到集合为空,表示全部连接已经被关闭释放,循环退出。

closeIdleConns 方法最终返回的 quiescent 标志位,是用来标记是否所有的连接都已经关闭。如果是,返回 true,否则,返回 false

这个方法的逻辑,其实就对应了前文讲解优雅退出流程中的第 2、3 两步。

至此,Shutdown 的源码就分析完成了。

Shutdown 方法的整个流程也完全是按照我们前文中讲解的优雅退出流程来的。

NOTE: 除了使用 Shutdown 进行优雅退出,net/http 包还为我们提供了 Close 方法用来强制退出,你可以自行尝试。

联系我