// Mirror of https://github.com/netbirdio/netbird.git
// (synced 2026-04-16 15:26:40 +00:00; 97 lines, 2.4 KiB, Go).

package conntrack
|
|
|
|
import (
|
|
"net/http"
|
|
"sync"
|
|
)
|
|
|
|
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
|
|
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
|
// they must be tracked and closed explicitly during graceful shutdown.
|
|
//
|
|
// Connections are indexed by the request Host so they can be closed
|
|
// per-domain when a service mapping is removed.
|
|
//
|
|
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
|
// connections are tracked and automatically deregistered when closed.
|
|
type HijackTracker struct {
|
|
mu sync.Mutex
|
|
conns map[*trackedConn]struct{}
|
|
}
|
|
|
|
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
|
// hijacked connections are tracked and automatically deregistered from the
|
|
// tracker when closed. This should be the outermost middleware in the chain.
|
|
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
next.ServeHTTP(&trackingWriter{
|
|
ResponseWriter: w,
|
|
tracker: t,
|
|
host: hostOnly(r.Host),
|
|
}, r)
|
|
})
|
|
}
|
|
|
|
// CloseAll closes all tracked hijacked connections and returns the count.
|
|
func (t *HijackTracker) CloseAll() int {
|
|
t.mu.Lock()
|
|
conns := t.conns
|
|
t.conns = nil
|
|
t.mu.Unlock()
|
|
|
|
for tc := range conns {
|
|
_ = tc.Conn.Close()
|
|
}
|
|
return len(conns)
|
|
}
|
|
|
|
// CloseByHost closes all tracked hijacked connections for the given host
|
|
// and returns the number of connections closed.
|
|
func (t *HijackTracker) CloseByHost(host string) int {
|
|
host = hostOnly(host)
|
|
t.mu.Lock()
|
|
var toClose []*trackedConn
|
|
for tc := range t.conns {
|
|
if tc.host == host {
|
|
toClose = append(toClose, tc)
|
|
}
|
|
}
|
|
for _, tc := range toClose {
|
|
delete(t.conns, tc)
|
|
}
|
|
t.mu.Unlock()
|
|
|
|
for _, tc := range toClose {
|
|
_ = tc.Conn.Close()
|
|
}
|
|
return len(toClose)
|
|
}
|
|
|
|
func (t *HijackTracker) add(tc *trackedConn) {
|
|
t.mu.Lock()
|
|
if t.conns == nil {
|
|
t.conns = make(map[*trackedConn]struct{})
|
|
}
|
|
t.conns[tc] = struct{}{}
|
|
t.mu.Unlock()
|
|
}
|
|
|
|
func (t *HijackTracker) remove(tc *trackedConn) {
|
|
t.mu.Lock()
|
|
delete(t.conns, tc)
|
|
t.mu.Unlock()
|
|
}
|
|
|
|
// hostOnly strips a trailing ":port" from a host:port string and returns
// the host part. Input without a port is returned unchanged.
//
// The string is scanned from the right: a run of ASCII digits (possibly
// empty) preceded by ':' is treated as the port. Brackets around an IPv6
// literal are kept, so "[::1]:443" yields "[::1]".
//
// An unbracketed value that still contains a colon before the candidate
// port separator (for example "::1" or "2001:db8::1") is a bare IPv6
// literal rather than host:port; it is returned unchanged instead of
// having its last group cut off.
func hostOnly(hostport string) string {
	for i := len(hostport) - 1; i >= 0; i-- {
		c := hostport[i]
		if c == ':' {
			// A host ending in ']' is a bracketed IPv6 literal; colons
			// inside the brackets are part of the address.
			bracketed := i > 0 && hostport[i-1] == ']'
			for j := 0; !bracketed && j < i; j++ {
				if hostport[j] == ':' {
					// More than one colon without brackets: not host:port.
					return hostport
				}
			}
			return hostport[:i]
		}
		if c < '0' || c > '9' {
			// A non-digit before any ':' means there is no port suffix.
			return hostport
		}
	}
	return hostport
}
|