Files
netbird/proxy/internal/tcp/proxyprotocol_test.go

129 lines
4.2 KiB
Go

package tcp
import (
"bufio"
"net"
"testing"
"github.com/pires/go-proxyproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWriteProxyProtoV2_IPv4(t *testing.T) {
// Set up a real TCP listener and dial to get connections with real addresses.
ln, err := net.Listen("tcp4", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
var serverConn net.Conn
accepted := make(chan struct{})
go func() {
var err error
serverConn, err = ln.Accept()
if err != nil {
t.Error("accept failed:", err)
}
close(accepted)
}()
clientConn, err := net.Dial("tcp4", ln.Addr().String())
require.NoError(t, err)
defer clientConn.Close()
<-accepted
defer serverConn.Close()
// Use a pipe as the backend: write the header to one end, read from the other.
backendRead, backendWrite := net.Pipe()
defer backendRead.Close()
defer backendWrite.Close()
// serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination.
writeDone := make(chan error, 1)
go func() {
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
}()
// Read the PROXY protocol header from the backend read side.
header, err := proxyproto.Read(bufio.NewReader(backendRead))
require.NoError(t, err)
require.NotNil(t, header, "should have received a proxy protocol header")
writeErr := <-writeDone
require.NoError(t, writeErr)
assert.Equal(t, byte(2), header.Version, "version should be 2")
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4")
// serverConn.RemoteAddr() is the client's address (source in the header).
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
actualSrc := header.SourceAddr.(*net.TCPAddr)
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
// serverConn.LocalAddr() is the server's address (destination in the header).
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
actualDst := header.DestinationAddr.(*net.TCPAddr)
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
}
func TestWriteProxyProtoV2_IPv6(t *testing.T) {
// Set up a real TCP6 listener on loopback.
ln, err := net.Listen("tcp6", "[::1]:0")
if err != nil {
t.Skip("IPv6 not available:", err)
}
defer ln.Close()
var serverConn net.Conn
accepted := make(chan struct{})
go func() {
var err error
serverConn, err = ln.Accept()
if err != nil {
t.Error("accept failed:", err)
}
close(accepted)
}()
clientConn, err := net.Dial("tcp6", ln.Addr().String())
require.NoError(t, err)
defer clientConn.Close()
<-accepted
defer serverConn.Close()
backendRead, backendWrite := net.Pipe()
defer backendRead.Close()
defer backendWrite.Close()
writeDone := make(chan error, 1)
go func() {
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
}()
header, err := proxyproto.Read(bufio.NewReader(backendRead))
require.NoError(t, err)
require.NotNil(t, header, "should have received a proxy protocol header")
writeErr := <-writeDone
require.NoError(t, writeErr)
assert.Equal(t, byte(2), header.Version, "version should be 2")
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6")
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
actualSrc := header.SourceAddr.(*net.TCPAddr)
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
actualDst := header.DestinationAddr.(*net.TCPAddr)
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
}