mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
129 lines
4.2 KiB
Go
129 lines
4.2 KiB
Go
package tcp
|
|
|
|
import (
|
|
"bufio"
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/pires/go-proxyproto"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestWriteProxyProtoV2_IPv4(t *testing.T) {
|
|
// Set up a real TCP listener and dial to get connections with real addresses.
|
|
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
var serverConn net.Conn
|
|
accepted := make(chan struct{})
|
|
go func() {
|
|
var err error
|
|
serverConn, err = ln.Accept()
|
|
if err != nil {
|
|
t.Error("accept failed:", err)
|
|
}
|
|
close(accepted)
|
|
}()
|
|
|
|
clientConn, err := net.Dial("tcp4", ln.Addr().String())
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
<-accepted
|
|
defer serverConn.Close()
|
|
|
|
// Use a pipe as the backend: write the header to one end, read from the other.
|
|
backendRead, backendWrite := net.Pipe()
|
|
defer backendRead.Close()
|
|
defer backendWrite.Close()
|
|
|
|
// serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination.
|
|
writeDone := make(chan error, 1)
|
|
go func() {
|
|
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
|
|
}()
|
|
|
|
// Read the PROXY protocol header from the backend read side.
|
|
header, err := proxyproto.Read(bufio.NewReader(backendRead))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, header, "should have received a proxy protocol header")
|
|
|
|
writeErr := <-writeDone
|
|
require.NoError(t, writeErr)
|
|
|
|
assert.Equal(t, byte(2), header.Version, "version should be 2")
|
|
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
|
|
assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4")
|
|
|
|
// serverConn.RemoteAddr() is the client's address (source in the header).
|
|
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
|
|
actualSrc := header.SourceAddr.(*net.TCPAddr)
|
|
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
|
|
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
|
|
|
|
// serverConn.LocalAddr() is the server's address (destination in the header).
|
|
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
|
|
actualDst := header.DestinationAddr.(*net.TCPAddr)
|
|
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
|
|
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
|
|
}
|
|
|
|
func TestWriteProxyProtoV2_IPv6(t *testing.T) {
|
|
// Set up a real TCP6 listener on loopback.
|
|
ln, err := net.Listen("tcp6", "[::1]:0")
|
|
if err != nil {
|
|
t.Skip("IPv6 not available:", err)
|
|
}
|
|
defer ln.Close()
|
|
|
|
var serverConn net.Conn
|
|
accepted := make(chan struct{})
|
|
go func() {
|
|
var err error
|
|
serverConn, err = ln.Accept()
|
|
if err != nil {
|
|
t.Error("accept failed:", err)
|
|
}
|
|
close(accepted)
|
|
}()
|
|
|
|
clientConn, err := net.Dial("tcp6", ln.Addr().String())
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
<-accepted
|
|
defer serverConn.Close()
|
|
|
|
backendRead, backendWrite := net.Pipe()
|
|
defer backendRead.Close()
|
|
defer backendWrite.Close()
|
|
|
|
writeDone := make(chan error, 1)
|
|
go func() {
|
|
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
|
|
}()
|
|
|
|
header, err := proxyproto.Read(bufio.NewReader(backendRead))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, header, "should have received a proxy protocol header")
|
|
|
|
writeErr := <-writeDone
|
|
require.NoError(t, writeErr)
|
|
|
|
assert.Equal(t, byte(2), header.Version, "version should be 2")
|
|
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
|
|
assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6")
|
|
|
|
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
|
|
actualSrc := header.SourceAddr.(*net.TCPAddr)
|
|
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
|
|
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
|
|
|
|
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
|
|
actualDst := header.DestinationAddr.(*net.TCPAddr)
|
|
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
|
|
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
|
|
}
|