mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
@@ -10,10 +10,11 @@ import (
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
tracker *HijackTracker
|
||||
host string
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.tracker.conns.Delete(c)
|
||||
c.tracker.remove(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
@@ -22,6 +23,7 @@ func (c *trackedConn) Close() error {
|
||||
type trackingWriter struct {
|
||||
http.ResponseWriter
|
||||
tracker *HijackTracker
|
||||
host string
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
@@ -33,8 +35,8 @@ func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||
w.tracker.conns.Store(tc, struct{}{})
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker, host: w.host}
|
||||
w.tracker.add(tc)
|
||||
return tc, buf, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
@@ -10,10 +9,14 @@ import (
|
||||
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||
// they must be tracked and closed explicitly during graceful shutdown.
|
||||
//
|
||||
// Connections are indexed by the request Host so they can be closed
|
||||
// per-domain when a service mapping is removed.
|
||||
//
|
||||
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||
// connections are tracked and automatically deregistered when closed.
|
||||
type HijackTracker struct {
|
||||
conns sync.Map // net.Conn → struct{}
|
||||
mu sync.Mutex
|
||||
conns map[*trackedConn]struct{}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||
@@ -21,21 +24,73 @@ type HijackTracker struct {
|
||||
// tracker when closed. This should be the outermost middleware in the chain.
|
||||
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||
next.ServeHTTP(&trackingWriter{
|
||||
ResponseWriter: w,
|
||||
tracker: t,
|
||||
host: hostOnly(r.Host),
|
||||
}, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CloseAll closes all tracked hijacked connections and returns the number
|
||||
// of connections that were closed.
|
||||
// CloseAll closes all tracked hijacked connections and returns the count.
|
||||
func (t *HijackTracker) CloseAll() int {
|
||||
var count int
|
||||
t.conns.Range(func(key, _ any) bool {
|
||||
if conn, ok := key.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
count++
|
||||
}
|
||||
t.conns.Delete(key)
|
||||
return true
|
||||
})
|
||||
return count
|
||||
t.mu.Lock()
|
||||
conns := t.conns
|
||||
t.conns = nil
|
||||
t.mu.Unlock()
|
||||
|
||||
for tc := range conns {
|
||||
_ = tc.Conn.Close()
|
||||
}
|
||||
return len(conns)
|
||||
}
|
||||
|
||||
// CloseByHost closes all tracked hijacked connections for the given host
|
||||
// and returns the number of connections closed.
|
||||
func (t *HijackTracker) CloseByHost(host string) int {
|
||||
host = hostOnly(host)
|
||||
t.mu.Lock()
|
||||
var toClose []*trackedConn
|
||||
for tc := range t.conns {
|
||||
if tc.host == host {
|
||||
toClose = append(toClose, tc)
|
||||
}
|
||||
}
|
||||
for _, tc := range toClose {
|
||||
delete(t.conns, tc)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
for _, tc := range toClose {
|
||||
_ = tc.Conn.Close()
|
||||
}
|
||||
return len(toClose)
|
||||
}
|
||||
|
||||
func (t *HijackTracker) add(tc *trackedConn) {
|
||||
t.mu.Lock()
|
||||
if t.conns == nil {
|
||||
t.conns = make(map[*trackedConn]struct{})
|
||||
}
|
||||
t.conns[tc] = struct{}{}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
func (t *HijackTracker) remove(tc *trackedConn) {
|
||||
t.mu.Lock()
|
||||
delete(t.conns, tc)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// hostOnly strips the port from a host:port string.
|
||||
func hostOnly(hostport string) string {
|
||||
for i := len(hostport) - 1; i >= 0; i-- {
|
||||
if hostport[i] == ':' {
|
||||
return hostport[:i]
|
||||
}
|
||||
if hostport[i] < '0' || hostport[i] > '9' {
|
||||
return hostport
|
||||
}
|
||||
}
|
||||
return hostport
|
||||
}
|
||||
|
||||
142
proxy/internal/conntrack/hijacked_test.go
Normal file
142
proxy/internal/conntrack/hijacked_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing.
|
||||
type fakeHijackWriter struct {
|
||||
http.ResponseWriter
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
||||
return f.conn, rw, nil
|
||||
}
|
||||
|
||||
func TestCloseByHost(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
// Simulate hijacking two connections for different hosts.
|
||||
connA1, connA2 := net.Pipe()
|
||||
defer connA2.Close()
|
||||
connB1, connB2 := net.Pipe()
|
||||
defer connB2.Close()
|
||||
|
||||
twA := &trackingWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
tracker: &tracker,
|
||||
host: "a.example.com",
|
||||
}
|
||||
twB := &trackingWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
tracker: &tracker,
|
||||
host: "b.example.com",
|
||||
}
|
||||
|
||||
// Use fakeHijackWriter to provide the Hijack method.
|
||||
twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1}
|
||||
twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1}
|
||||
|
||||
_, _, err := twA.Hijack()
|
||||
require.NoError(t, err)
|
||||
_, _, err = twB.Hijack()
|
||||
require.NoError(t, err)
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 2, len(tracker.conns), "should track 2 connections")
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Close only host A.
|
||||
n := tracker.CloseByHost("a.example.com")
|
||||
assert.Equal(t, 1, n, "should close 1 connection for host A")
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection")
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Verify host A's conn is actually closed.
|
||||
buf := make([]byte, 1)
|
||||
_, err = connA2.Read(buf)
|
||||
assert.Error(t, err, "host A pipe should be closed")
|
||||
|
||||
// Host B should still be alive.
|
||||
go func() { _, _ = connB1.Write([]byte("x")) }()
|
||||
|
||||
// Close all remaining.
|
||||
n = tracker.CloseAll()
|
||||
assert.Equal(t, 1, n, "should close remaining 1 connection")
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll")
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestCloseAll(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
for range 5 {
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
tc := &trackedConn{Conn: c1, tracker: &tracker, host: "test.com"}
|
||||
tracker.add(tc)
|
||||
}
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 5, len(tracker.conns))
|
||||
tracker.mu.Unlock()
|
||||
|
||||
n := tracker.CloseAll()
|
||||
assert.Equal(t, 5, n)
|
||||
|
||||
// Double CloseAll is safe.
|
||||
n = tracker.CloseAll()
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func TestTrackedConn_AutoDeregister(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
|
||||
tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"}
|
||||
tracker.add(tc)
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 1, len(tracker.conns))
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Close the tracked conn: should auto-deregister.
|
||||
require.NoError(t, tc.Close())
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close")
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestHostOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"example.com:443", "example.com"},
|
||||
{"example.com", "example.com"},
|
||||
{"127.0.0.1:8080", "127.0.0.1"},
|
||||
{"[::1]:443", "[::1]"},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, hostOnly(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user