mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-31 13:09:55 +00:00
319 lines
10 KiB
Go
319 lines
10 KiB
Go
//go:build !js && !ios && !android
|
|
|
|
package server
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"image"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// testCapturer is a stub screen capturer for test sessions: it always reports
// a 100x100 screen and hands out a freshly allocated RGBA frame of that size.
type testCapturer struct{}

// Width returns the fixed frame width in pixels.
func (c *testCapturer) Width() int {
	return 100
}

// Height returns the fixed frame height in pixels.
func (c *testCapturer) Height() int {
	return 100
}

// Capture allocates a new RGBA image covering (0,0)-(Width,Height). It never
// returns an error.
func (c *testCapturer) Capture() (*image.RGBA, error) {
	bounds := image.Rect(0, 0, c.Width(), c.Height())
	frame := image.NewRGBA(bounds)
	return frame, nil
}
|
|
|
|
func startTestServer(t *testing.T, disableAuth bool) (net.Addr, *Server) {
|
|
t.Helper()
|
|
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
|
srv.SetDisableAuth(disableAuth)
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
return srv.listener.Addr(), srv
|
|
}
|
|
|
|
func TestAuthEnabled_NoSessionAuth_RejectsConnection(t *testing.T) {
|
|
addr, _ := startTestServer(t, false)
|
|
|
|
conn, err := net.Dial("tcp", addr.String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
// Header with no Noise handshake. Auth-required servers must reject
|
|
// because no client static was authenticated.
|
|
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
|
header[0] = ModeAttach
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
|
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
|
|
|
|
var reasonLen [4]byte
|
|
_, err = io.ReadFull(conn, reasonLen[:])
|
|
require.NoError(t, err)
|
|
|
|
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
|
_, err = io.ReadFull(conn, reason)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, string(reason), "identity proof missing", "rejection reason should mention missing identity proof")
|
|
}
|
|
|
|
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
|
addr, _ := startTestServer(t, true)
|
|
|
|
conn, err := net.Dial("tcp", addr.String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
|
header[0] = ModeAttach
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
// Server should send RFB version.
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
|
|
|
// Write client version.
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
// Should get security types (not 0 = failure).
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
|
}
|
|
|
|
// TestAuth_NoUnauthBytesPastHeader proves the server does not send any RFB
|
|
// content to a connection that fails source validation. Specifically, the
|
|
// server must close immediately and the client must see EOF before any RFB
|
|
// version greeting is written.
|
|
func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
|
srv.SetDisableAuth(true)
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
// Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so
|
|
// the loopback short-circuit in isAllowedSource doesn't apply.
|
|
require.NoError(t, srv.Start(t.Context(), addr, netip.MustParsePrefix("10.99.0.0/16")))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second)))
|
|
|
|
// Reading even one byte must EOF: the source IP (127.0.0.1) is outside
|
|
// the configured overlay, so handleConnection closes before writing.
|
|
var b [1]byte
|
|
_, err = io.ReadFull(conn, b[:])
|
|
require.Error(t, err, "non-overlay client must see EOF, not an RFB greeting")
|
|
}
|
|
|
|
func TestIsAllowedSource(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
localAddr netip.Addr
|
|
network netip.Prefix
|
|
remote net.Addr
|
|
want bool
|
|
}{
|
|
{
|
|
name: "non-tcp address rejected",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.UDPAddr{IP: net.ParseIP("10.99.99.2"), Port: 1234},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "own IP rejected",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.1"), Port: 5900},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "non-overlay IP rejected",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 5900},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "overlay IP allowed",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "v4-mapped v6 in overlay allowed (unmapped)",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("::ffff:10.99.99.2"), Port: 5900},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "loopback allowed only when local is loopback",
|
|
localAddr: netip.MustParseAddr("127.0.0.1"),
|
|
network: netip.MustParsePrefix("127.0.0.0/8"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.5"), Port: 5900},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "invalid network rejected (fail-closed)",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.Prefix{},
|
|
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
|
want: false,
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
|
srv.localAddr = tc.localAddr
|
|
srv.network = tc.network
|
|
assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStart_InvalidNetworkRejected(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
err := srv.Start(t.Context(), addr, netip.Prefix{})
|
|
require.Error(t, err, "Start must refuse an invalid overlay prefix")
|
|
assert.Contains(t, err.Error(), "invalid overlay network prefix")
|
|
}
|
|
|
|
func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
|
srv.SetDisableAuth(true)
|
|
srv.SetAgentToken("deadbeefcafebabe")
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
|
|
|
// Send a wrong token of the right length (8 bytes hex-decoded).
|
|
if _, err := conn.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}); err != nil {
|
|
// Server may already have closed; either way the read below must EOF.
|
|
_ = err
|
|
}
|
|
|
|
// Server must close without sending the RFB greeting.
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.Error(t, err, "server must close the connection on bad agent token")
|
|
}
|
|
|
|
func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
|
srv.SetDisableAuth(true)
|
|
const tokenHex = "deadbeefcafebabe"
|
|
srv.SetAgentToken(tokenHex)
|
|
token, err := hex.DecodeString(tokenHex)
|
|
require.NoError(t, err)
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
|
|
|
_, err = conn.Write(token)
|
|
require.NoError(t, err)
|
|
|
|
// Send session header so handleConnection can proceed past readConnectionHeader.
|
|
header := make([]byte, 11) // ModeAttach + usernameLen=0 + sessionID=0 + width=0 + height=0
|
|
header[0] = ModeAttach
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
// With a matching token the server proceeds to the RFB greeting.
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err, "server must keep the connection open after a valid agent token")
|
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
|
}
|
|
|
|
func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
|
|
// Default platformSessionManager() on non-Linux returns nil, so ModeSession
|
|
// must be rejected with the UNSUPPORTED reason rather than crashing.
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
|
srv.SetDisableAuth(true)
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
// Force vmgr to nil regardless of platform so the test is deterministic.
|
|
srv.vmgr = nil
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
|
|
|
// ModeSession with no username, so we exit on the vmgr==nil branch
|
|
// before username validation runs.
|
|
header := []byte{ModeSession, 0, 0, 0, 0}
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, byte(0), numTypes[0])
|
|
|
|
var reasonLen [4]byte
|
|
_, err = io.ReadFull(conn, reasonLen[:])
|
|
require.NoError(t, err)
|
|
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
|
_, err = io.ReadFull(conn, reason)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, string(reason), RejectCodeUnsupportedOS)
|
|
}
|