Files
netbird/proxy/internal/conntrack/hijacked.go

97 lines
2.4 KiB
Go

package conntrack
import (
"net/http"
"sync"
)
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
// upgrades). http.Server.Shutdown does not close hijacked connections, so
// they must be tracked and closed explicitly during graceful shutdown.
//
// Connections are indexed by the request Host so they can be closed
// per-domain when a service mapping is removed.
//
// Use Middleware as the outermost HTTP middleware to ensure hijacked
// connections are tracked and automatically deregistered when closed.
type HijackTracker struct {
mu sync.Mutex
conns map[*trackedConn]struct{}
}
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
// hijacked connections are tracked and automatically deregistered from the
// tracker when closed. This should be the outermost middleware in the chain.
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(&trackingWriter{
ResponseWriter: w,
tracker: t,
host: hostOnly(r.Host),
}, r)
})
}
// CloseAll closes all tracked hijacked connections and returns the count.
func (t *HijackTracker) CloseAll() int {
t.mu.Lock()
conns := t.conns
t.conns = nil
t.mu.Unlock()
for tc := range conns {
_ = tc.Conn.Close()
}
return len(conns)
}
// CloseByHost closes all tracked hijacked connections for the given host
// and returns the number of connections closed.
func (t *HijackTracker) CloseByHost(host string) int {
host = hostOnly(host)
t.mu.Lock()
var toClose []*trackedConn
for tc := range t.conns {
if tc.host == host {
toClose = append(toClose, tc)
}
}
for _, tc := range toClose {
delete(t.conns, tc)
}
t.mu.Unlock()
for _, tc := range toClose {
_ = tc.Conn.Close()
}
return len(toClose)
}
func (t *HijackTracker) add(tc *trackedConn) {
t.mu.Lock()
if t.conns == nil {
t.conns = make(map[*trackedConn]struct{})
}
t.conns[tc] = struct{}{}
t.mu.Unlock()
}
func (t *HijackTracker) remove(tc *trackedConn) {
t.mu.Lock()
delete(t.conns, tc)
t.mu.Unlock()
}
// hostOnly strips the port from a host:port string.
func hostOnly(hostport string) string {
for i := len(hostport) - 1; i >= 0; i-- {
if hostport[i] == ':' {
return hostport[:i]
}
if hostport[i] < '0' || hostport[i] > '9' {
return hostport
}
}
return hostport
}