mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 15:46:38 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
142
proxy/internal/conntrack/hijacked_test.go
Normal file
142
proxy/internal/conntrack/hijacked_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing.
|
||||
type fakeHijackWriter struct {
|
||||
http.ResponseWriter
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
||||
return f.conn, rw, nil
|
||||
}
|
||||
|
||||
func TestCloseByHost(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
// Simulate hijacking two connections for different hosts.
|
||||
connA1, connA2 := net.Pipe()
|
||||
defer connA2.Close()
|
||||
connB1, connB2 := net.Pipe()
|
||||
defer connB2.Close()
|
||||
|
||||
twA := &trackingWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
tracker: &tracker,
|
||||
host: "a.example.com",
|
||||
}
|
||||
twB := &trackingWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
tracker: &tracker,
|
||||
host: "b.example.com",
|
||||
}
|
||||
|
||||
// Use fakeHijackWriter to provide the Hijack method.
|
||||
twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1}
|
||||
twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1}
|
||||
|
||||
_, _, err := twA.Hijack()
|
||||
require.NoError(t, err)
|
||||
_, _, err = twB.Hijack()
|
||||
require.NoError(t, err)
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 2, len(tracker.conns), "should track 2 connections")
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Close only host A.
|
||||
n := tracker.CloseByHost("a.example.com")
|
||||
assert.Equal(t, 1, n, "should close 1 connection for host A")
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection")
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Verify host A's conn is actually closed.
|
||||
buf := make([]byte, 1)
|
||||
_, err = connA2.Read(buf)
|
||||
assert.Error(t, err, "host A pipe should be closed")
|
||||
|
||||
// Host B should still be alive.
|
||||
go func() { _, _ = connB1.Write([]byte("x")) }()
|
||||
|
||||
// Close all remaining.
|
||||
n = tracker.CloseAll()
|
||||
assert.Equal(t, 1, n, "should close remaining 1 connection")
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll")
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestCloseAll(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
for range 5 {
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
tc := &trackedConn{Conn: c1, tracker: &tracker, host: "test.com"}
|
||||
tracker.add(tc)
|
||||
}
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 5, len(tracker.conns))
|
||||
tracker.mu.Unlock()
|
||||
|
||||
n := tracker.CloseAll()
|
||||
assert.Equal(t, 5, n)
|
||||
|
||||
// Double CloseAll is safe.
|
||||
n = tracker.CloseAll()
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func TestTrackedConn_AutoDeregister(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
|
||||
tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"}
|
||||
tracker.add(tc)
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 1, len(tracker.conns))
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Close the tracked conn: should auto-deregister.
|
||||
require.NoError(t, tc.Close())
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close")
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestHostOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"example.com:443", "example.com"},
|
||||
{"example.com", "example.com"},
|
||||
{"127.0.0.1:8080", "127.0.0.1"},
|
||||
{"[::1]:443", "[::1]"},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, hostOnly(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user