kubernetes apiserver源码: SafeWaitGroup

201 阅读1分钟
// NonBlockingRun spawns the secure http server. An error is
// returned if the secure port cannot be listened on.
// The returned channel is closed when the (asynchronous) termination is finished.
// The returned channel is nil when SecureServingInfo or Handler is nil, because
// no server is started in that case.
func (s preparedGenericAPIServer) NonBlockingRun(stopCh <-chan struct{}) (<-chan struct{}, error) {
	

	// NOTE(review): auditStopCh is used below but is not declared anywhere in
	// this excerpt; presumably its declaration was elided from the snippet —
	// confirm against the full source.

	// Use an internal stop channel to allow cleanup of the listeners on error.
	internalStopCh := make(chan struct{})
	var stoppedCh <-chan struct{}
	if s.SecureServingInfo != nil && s.Handler != nil {
		var err error
		stoppedCh, err = s.SecureServingInfo.Serve(s.Handler, s.ShutdownTimeout, internalStopCh)
		if err != nil {
			// Serving could not start: release both stop channels before reporting the error.
			close(internalStopCh)
			close(auditStopCh)
			return nil, err
		}
	}

	// Now that listener have bound successfully, it is the
	// responsibility of the caller to close the provided channel to
	// ensure cleanup.
	go func() {
		<-stopCh // wait for the stop signal
		close(internalStopCh)
		if stoppedCh != nil {
			// Wait until the channel returned by Serve has been closed.
			<-stoppedCh
		}
		s.HandlerChainWaitGroup.Wait() // wait for all tracked requests to finish
		close(auditStopCh)
	}()
            
        
        
	return stoppedCh, nil
}

等待所有请求处理完成(无损下线)是如何实现的?

SafeWaitGroup

这里是对sync.WaitGroup做了一个封装:

  • 当apiserver收到请求的时候,会调用Add方法添加一个请求计数
  • 当收到关闭信号时,会调用Wait方法;Wait被调用之后,再以正数调用Add会返回错误,apiserver据此拒绝服务新的请求
  • 已接收的请求在处理完成后会调用Done()方法将请求计数递减,直到计数为0,此时Wait方法返回
package waitgroup

import (
	"fmt"
	"sync"
)

// SafeWaitGroup is a sync.WaitGroup wrapper that refuses to grow its counter
// once Wait has been called. It must not be copied after first use.
type SafeWaitGroup struct {
	wg sync.WaitGroup
	mu sync.RWMutex
	// wait records that Wait has been invoked; from that point on every
	// Add with a positive delta is rejected with an error.
	wait bool
}

// Add changes the counter by delta (which may be negative), mirroring
// sync.WaitGroup.Add. A positive delta supplied after Wait has been called
// is rejected with an error instead of being applied.
func (g *SafeWaitGroup) Add(delta int) error {
	// The read lock is held across the counter update so that Wait cannot
	// flip the flag between the check and the underlying Add.
	g.mu.RLock()
	defer g.mu.RUnlock()

	rejected := delta > 0 && g.wait
	if rejected {
		return fmt.Errorf("add with positive delta after Wait is forbidden")
	}

	g.wg.Add(delta)
	return nil
}

// Done decrements the counter by one.
func (g *SafeWaitGroup) Done() {
	g.wg.Done()
}

// Wait marks the group as waiting and then blocks until the counter
// reaches zero.
func (g *SafeWaitGroup) Wait() {
	g.markWaiting()
	g.wg.Wait()
}

// markWaiting sets the wait flag under the write lock.
func (g *SafeWaitGroup) markWaiting() {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.wait = true
}

如何使用SafeWaitGroup?

在Handler处理链里加入其中一环

// WithWaitGroup wraps handler so that every request not classified as
// long-running by longRunning is counted in wg while it is being served.
// Once wg refuses new additions, such requests are answered with a
// "service unavailable" status and a Retry-After header instead of being
// passed to handler. This is used for graceful shutdown.
func WithWaitGroup(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, wg *utilwaitgroup.SafeWaitGroup) http.Handler {
	serve := func(w http.ResponseWriter, req *http.Request) {
		info, found := apirequest.RequestInfoFrom(req.Context())
		if !found {
			// Without request info the handler chain was not assembled correctly.
			responsewriters.InternalError(w, req, errors.New("no RequestInfo found in the context"))
			return
		}

		// Long-running requests (e.g. watch) are served without being tracked.
		if longRunning(req, info) {
			handler.ServeHTTP(w, req)
			return
		}

		// Count this request; Add fails once the wait group has started waiting.
		if addErr := wg.Add(1); addErr != nil {
			// The apiserver is shutting down: tell the client to retry. The retry
			// will likely land on a different server, so a short delay keeps
			// clients responsive.
			w.Header().Add("Retry-After", "1")
			w.Header().Set("Content-Type", runtime.ContentTypeJSON)
			w.Header().Set("X-Content-Type-Options", "nosniff")
			unavailable := apierrors.NewServiceUnavailable("apiserver is shutting down").Status()
			w.WriteHeader(int(unavailable.Code))
			fmt.Fprintln(w, runtime.EncodeOrDie(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), &unavailable))
			return
		}
		// Release the count once the wrapped handler has finished.
		defer wg.Done()

		handler.ServeHTTP(w, req)
	}

	return http.HandlerFunc(serve)
}

怎样判断是否长连接?

简单来看,常用的watch等请求会被判定为长连接。这些长连接大多是watch这类流式读请求(通常是HTTP流式响应,并不一定是websocket),不会对服务器状态进行更改。而无损下线的目的是防止服务状态出现不一致,读请求不会更改状态,所以不需要等待这些读请求处理完成。

// BasicLongRunningRequestCheck builds a check that reports a request as
// long-running when its verb is in longRunningVerbs, when it is a resource
// request whose subresource is in longRunningSubresources, or when it is a
// non-resource request whose path starts with "/debug/pprof/".
func BasicLongRunningRequestCheck(longRunningVerbs, longRunningSubresources sets.String) apirequest.LongRunningRequestCheck {
	return func(_ *http.Request, info *apirequest.RequestInfo) bool {
		switch {
		case longRunningVerbs.Has(info.Verb):
			return true
		case info.IsResourceRequest:
			// Resource requests qualify only through their subresource.
			return longRunningSubresources.Has(info.Subresource)
		default:
			// Non-resource requests qualify only as profiler requests.
			return strings.HasPrefix(info.Path, "/debug/pprof/")
		}
	}
}