gitea/modules/graceful/server.go
Lunny Xiao 304d836a61
Fix shutdown waitgroup panic (#35676)
This PR fixes a panic issue in the WaitGroup that occurs when Gitea is
shut down using Ctrl+C.
It ensures that all active connection pointers in the server are
properly tracked and forcibly closed when the hammer shutdown is
invoked.
The process remains graceful — the normal shutdown sequence runs before
the hammer is triggered, and existing connections are given a timeout
period to complete gracefully.

This PR also fixes `no logger writer` problem. Now the log close will
only be invoked when the command exit.

- Fixes #35468
- Fixes #35551
- Fixes #35559
- Replace #35578

---------

Co-authored-by: wxiaoguang <wxiaoguang@gmail.com>
2025-10-25 00:02:58 -07:00

294 lines
8.0 KiB
Go

// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
// This code is highly inspired by endless go
package graceful
import (
"crypto/tls"
"net"
"os"
"strings"
"sync"
"syscall"
"time"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/proxyprotocol"
"code.gitea.io/gitea/modules/setting"
)
// GetListener returns a net listener
// This determines the implementation of net.Listener which the server will use,
// so that downstreams could provide their own Listener, such as with a hidden service or a p2p network
var GetListener = DefaultGetListener
// ServeFunction represents a listen.Accept loop
type ServeFunction = func(net.Listener) error
// Server represents our graceful server
type Server struct {
network string
address string
listener net.Listener
lock sync.RWMutex
state state
connCounter int64
connEmptyCond *sync.Cond
BeforeBegin func(network, address string)
OnShutdown func()
PerWriteTimeout time.Duration
PerWritePerKbTimeout time.Duration
}
// NewServer creates a server on network at provided address
func NewServer(network, address, name string) *Server {
if GetManager().IsChild() {
log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
} else {
log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
}
srv := &Server{
state: stateInit,
network: network,
address: address,
PerWriteTimeout: setting.PerWriteTimeout,
PerWritePerKbTimeout: setting.PerWritePerKbTimeout,
}
srv.connEmptyCond = sync.NewCond(&srv.lock)
srv.BeforeBegin = func(network, addr string) {
log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
}
return srv
}
// ListenAndServe listens on the provided network address and then calls Serve
// to handle requests on incoming connections.
func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error {
go srv.awaitShutdown()
listener, err := GetListener(srv.network, srv.address)
if err != nil {
log.Error("Unable to GetListener: %v", err)
return err
}
// we need to wrap the listener to take account of our lifecycle
listener = newWrappedListener(listener, srv)
// Now we need to take account of ProxyProtocol settings...
if useProxyProtocol {
listener = &proxyprotocol.Listener{
Listener: listener,
ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
}
}
srv.listener = listener
srv.BeforeBegin(srv.network, srv.address)
return srv.Serve(serve)
}
// ListenAndServeTLSConfig listens on the provided network address and then calls
// Serve to handle requests on incoming TLS connections.
func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error {
go srv.awaitShutdown()
if tlsConfig.MinVersion == 0 {
tlsConfig.MinVersion = tls.VersionTLS12
}
listener, err := GetListener(srv.network, srv.address)
if err != nil {
log.Error("Unable to get Listener: %v", err)
return err
}
// we need to wrap the listener to take account of our lifecycle
listener = newWrappedListener(listener, srv)
// Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us
if useProxyProtocol && !proxyProtocolTLSBridging {
listener = &proxyprotocol.Listener{
Listener: listener,
ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
}
}
// Now handle the tls protocol
listener = tls.NewListener(listener, tlsConfig)
// Now if we're bridging then we need the proxy to tell us who we're bridging for...
if useProxyProtocol && proxyProtocolTLSBridging {
listener = &proxyprotocol.Listener{
Listener: listener,
ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
}
}
srv.listener = listener
srv.BeforeBegin(srv.network, srv.address)
return srv.Serve(serve)
}
// Serve accepts incoming HTTP connections on the wrapped listener l, creating a new
// service goroutine for each. The service goroutines read requests and then call
// handler to reply to them. Handler is typically nil, in which case the
// DefaultServeMux is used.
//
// In addition to the standard Serve behaviour each connection is added to a
// sync.Waitgroup so that all outstanding connections can be served before shutting
// down the server.
func (srv *Server) Serve(serve ServeFunction) error {
defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
srv.setState(stateRunning)
GetManager().RegisterServer()
err := serve(srv.listener)
log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
srv.waitForActiveConnections()
srv.setState(stateTerminate)
GetManager().ServerDone()
// use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") {
return nil
}
return err
}
func (srv *Server) getState() state {
srv.lock.RLock()
defer srv.lock.RUnlock()
return srv.state
}
func (srv *Server) setState(st state) {
srv.lock.Lock()
defer srv.lock.Unlock()
srv.state = st
}
func (srv *Server) waitForActiveConnections() {
srv.lock.Lock()
for srv.connCounter > 0 {
srv.connEmptyCond.Wait()
}
srv.lock.Unlock()
}
func (srv *Server) wrapConnection(c net.Conn) (net.Conn, error) {
srv.lock.Lock()
defer srv.lock.Unlock()
if srv.state != stateRunning {
_ = c.Close()
return nil, syscall.EINVAL // same as AcceptTCP
}
srv.connCounter++
return &wrappedConn{Conn: c, server: srv}, nil
}
func (srv *Server) removeConnection(_ *wrappedConn) {
srv.lock.Lock()
defer srv.lock.Unlock()
srv.connCounter--
if srv.connCounter <= 0 {
srv.connEmptyCond.Broadcast()
}
}
// closeAllConnections forcefully closes all active connections
func (srv *Server) closeAllConnections() {
srv.lock.Lock()
if srv.connCounter > 0 {
log.Warn("After graceful shutdown period, %d connections are still active. Forcefully close.", srv.connCounter)
srv.connCounter = 0 // OS will close all the connections after the process exits, so we just assume there is no active connection now
}
srv.lock.Unlock()
srv.connEmptyCond.Broadcast()
}
type filer interface {
File() (*os.File, error)
}
type wrappedListener struct {
net.Listener
server *Server
}
var (
_ net.Listener = (*wrappedListener)(nil)
_ filer = (*wrappedListener)(nil)
)
func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
return &wrappedListener{
Listener: l,
server: srv,
}
}
func (wl *wrappedListener) Accept() (c net.Conn, err error) {
if tcl, ok := wl.Listener.(*net.TCPListener); ok {
// Set keepalive on TCPListeners connections if possible, see http.tcpKeepAliveListener
tc, err := tcl.AcceptTCP()
if err != nil {
return nil, err
}
_ = tc.SetKeepAlive(true)
_ = tc.SetKeepAlivePeriod(3 * time.Minute)
c = tc
} else {
c, err = wl.Listener.Accept()
if err != nil {
return nil, err
}
}
return wl.server.wrapConnection(c)
}
func (wl *wrappedListener) File() (*os.File, error) {
// returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
return wl.Listener.(filer).File()
}
type wrappedConn struct {
net.Conn
server *Server
deadline time.Time
}
func (w *wrappedConn) Write(p []byte) (n int, err error) {
if w.server.PerWriteTimeout > 0 {
minTimeout := time.Duration(len(p)/1024) * w.server.PerWritePerKbTimeout
minDeadline := time.Now().Add(minTimeout).Add(w.server.PerWriteTimeout)
w.deadline = w.deadline.Add(minTimeout)
if minDeadline.After(w.deadline) {
w.deadline = minDeadline
}
_ = w.Conn.SetWriteDeadline(w.deadline)
}
return w.Conn.Write(p)
}
func (w *wrappedConn) Close() error {
w.server.removeConnection(w)
return w.Conn.Close()
}