mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-19 15:19:55 +00:00
413 lines
14 KiB
Go
413 lines
14 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"image"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// testCapturer returns a 100x100 image for test sessions.
|
|
type testCapturer struct{}
|
|
|
|
func (t *testCapturer) Width() int { return 100 }
|
|
func (t *testCapturer) Height() int { return 100 }
|
|
func (t *testCapturer) Capture() (*image.RGBA, error) {
|
|
return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil
|
|
}
|
|
|
|
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
|
|
t.Helper()
|
|
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
|
srv.SetDisableAuth(disableAuth)
|
|
if jwtConfig != nil {
|
|
srv.SetJWTConfig(jwtConfig)
|
|
}
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
return srv.listener.Addr(), srv
|
|
}
|
|
|
|
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
|
|
addr, _ := startTestServer(t, false, nil)
|
|
|
|
conn, err := net.Dial("tcp", addr.String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
// Send session header: attach mode, no username, no JWT.
|
|
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
|
header[0] = ModeAttach
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
// Server should send RFB version then security failure.
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
|
|
|
// Write client version to proceed through handshake.
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
// Read security types: 0 means failure, followed by reason.
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
|
|
|
|
var reasonLen [4]byte
|
|
_, err = io.ReadFull(conn, reasonLen[:])
|
|
require.NoError(t, err)
|
|
|
|
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
|
_, err = io.ReadFull(conn, reason)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
|
|
}
|
|
|
|
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
|
addr, _ := startTestServer(t, true, nil)
|
|
|
|
conn, err := net.Dial("tcp", addr.String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
// Send session header: attach mode, no username, no JWT.
|
|
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
|
header[0] = ModeAttach
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
// Server should send RFB version.
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
|
|
|
// Write client version.
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
// Should get security types (not 0 = failure).
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
|
}
|
|
|
|
// TestAuthEnabled_InvalidJWT_RejectedBeforeRFB confirms the VNC server itself
|
|
// (not just the JWT library) wires authentication into handleConnection. A
|
|
// well-formed JWT-shaped token must hit the server's validation path and be
|
|
// rejected with an AUTH_JWT_* reason, never reaching the RFB handshake.
|
|
func TestAuthEnabled_InvalidJWT_RejectedBeforeRFB(t *testing.T) {
|
|
addr, _ := startTestServer(t, false, &JWTConfig{
|
|
Issuer: "https://example.invalid",
|
|
KeysLocation: "https://example.invalid/.well-known/jwks.json",
|
|
Audiences: []string{"test"},
|
|
})
|
|
|
|
// Three-segment "JWT" with bogus base64. The server's authenticateJWT path
|
|
// must catch this regardless of the IdP being unreachable.
|
|
bogusJWT := "abc.def.ghi"
|
|
header := make([]byte, 3+2+len(bogusJWT)+4+4)
|
|
header[0] = ModeAttach
|
|
binary.BigEndian.PutUint16(header[1:3], 0) // username len
|
|
binary.BigEndian.PutUint16(header[3:5], uint16(len(bogusJWT)))
|
|
copy(header[5:5+len(bogusJWT)], bogusJWT)
|
|
|
|
conn, err := net.Dial("tcp", addr.String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
|
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
require.Equal(t, byte(0), numTypes[0], "must fail security negotiation")
|
|
|
|
var reasonLen [4]byte
|
|
_, err = io.ReadFull(conn, reasonLen[:])
|
|
require.NoError(t, err)
|
|
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
|
_, err = io.ReadFull(conn, reason)
|
|
require.NoError(t, err)
|
|
// The reason must carry one of the server's AUTH_JWT_* codes, proving
|
|
// the rejection came from authenticateJWT in handleConnection.
|
|
r := string(reason)
|
|
hasJWTReject := false
|
|
for _, code := range []string{RejectCodeJWTInvalid, RejectCodeJWTExpired, RejectCodeAuthForbidden} {
|
|
if strings.Contains(r, code) {
|
|
hasJWTReject = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(t, hasJWTReject, "reason %q must include an AUTH_JWT_* code", r)
|
|
}
|
|
|
|
// TestAuth_NoUnauthBytesPastHeader proves the server does not send any RFB
|
|
// content to a connection that fails source validation. Specifically, the
|
|
// server must close immediately and the client must see EOF before any RFB
|
|
// version greeting is written.
|
|
func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
|
srv.SetDisableAuth(true)
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
// Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so
|
|
// the loopback short-circuit in isAllowedSource doesn't apply.
|
|
require.NoError(t, srv.Start(t.Context(), addr, netip.MustParsePrefix("10.99.0.0/16")))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second)))
|
|
|
|
// Reading even one byte must EOF: the source IP (127.0.0.1) is outside
|
|
// the configured overlay, so handleConnection closes before writing.
|
|
var b [1]byte
|
|
_, err = io.ReadFull(conn, b[:])
|
|
require.Error(t, err, "non-overlay client must see EOF, not an RFB greeting")
|
|
}
|
|
|
|
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
|
|
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
|
|
addr, _ := startTestServer(t, false, &JWTConfig{
|
|
Issuer: "https://example.com",
|
|
KeysLocation: "https://example.com/.well-known/jwks.json",
|
|
Audiences: []string{"test"},
|
|
})
|
|
|
|
conn, err := net.Dial("tcp", addr.String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
// Send session header with empty JWT.
|
|
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
|
header[0] = ModeAttach
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
|
|
}
|
|
|
|
func TestIsAllowedSource(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
localAddr netip.Addr
|
|
network netip.Prefix
|
|
remote net.Addr
|
|
want bool
|
|
}{
|
|
{
|
|
name: "non-tcp address rejected",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.UDPAddr{IP: net.ParseIP("10.99.99.2"), Port: 1234},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "own IP rejected",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.1"), Port: 5900},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "non-overlay IP rejected",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 5900},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "overlay IP allowed",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "v4-mapped v6 in overlay allowed (unmapped)",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.MustParsePrefix("10.99.0.0/16"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("::ffff:10.99.99.2"), Port: 5900},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "loopback allowed only when local is loopback",
|
|
localAddr: netip.MustParseAddr("127.0.0.1"),
|
|
network: netip.MustParsePrefix("127.0.0.0/8"),
|
|
remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.5"), Port: 5900},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "invalid network rejected (fail-closed)",
|
|
localAddr: netip.MustParseAddr("10.99.99.1"),
|
|
network: netip.Prefix{},
|
|
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
|
want: false,
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
|
srv.localAddr = tc.localAddr
|
|
srv.network = tc.network
|
|
assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStart_InvalidNetworkRejected(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
err := srv.Start(t.Context(), addr, netip.Prefix{})
|
|
require.Error(t, err, "Start must refuse an invalid overlay prefix")
|
|
assert.Contains(t, err.Error(), "invalid overlay network prefix")
|
|
}
|
|
|
|
func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
|
srv.SetDisableAuth(true)
|
|
srv.SetAgentToken("deadbeefcafebabe")
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
|
|
|
// Send a wrong token of the right length (8 bytes hex-decoded).
|
|
if _, err := conn.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}); err != nil {
|
|
// Server may already have closed; either way the read below must EOF.
|
|
_ = err
|
|
}
|
|
|
|
// Server must close without sending the RFB greeting.
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.Error(t, err, "server must close the connection on bad agent token")
|
|
}
|
|
|
|
func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
|
srv.SetDisableAuth(true)
|
|
const tokenHex = "deadbeefcafebabe"
|
|
srv.SetAgentToken(tokenHex)
|
|
token, err := hex.DecodeString(tokenHex)
|
|
require.NoError(t, err)
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
|
|
|
_, err = conn.Write(token)
|
|
require.NoError(t, err)
|
|
|
|
// Send session header so handleConnection can proceed past readConnectionHeader.
|
|
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
|
header[0] = ModeAttach
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
// With a matching token the server proceeds to the RFB greeting.
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err, "server must keep the connection open after a valid agent token")
|
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
|
}
|
|
|
|
func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
|
|
// Default platformSessionManager() on non-Linux returns nil, so ModeSession
|
|
// must be rejected with the UNSUPPORTED reason rather than crashing.
|
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
|
srv.SetDisableAuth(true)
|
|
|
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
|
// Force vmgr to nil regardless of platform so the test is deterministic.
|
|
srv.vmgr = nil
|
|
t.Cleanup(func() { _ = srv.Stop() })
|
|
|
|
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
|
|
|
// ModeSession with no username/JWT, so we exit on the vmgr==nil branch
|
|
// before username validation runs.
|
|
header := []byte{ModeSession, 0, 0, 0, 0}
|
|
_, err = conn.Write(header)
|
|
require.NoError(t, err)
|
|
|
|
var version [12]byte
|
|
_, err = io.ReadFull(conn, version[:])
|
|
require.NoError(t, err)
|
|
_, err = conn.Write(version[:])
|
|
require.NoError(t, err)
|
|
|
|
var numTypes [1]byte
|
|
_, err = io.ReadFull(conn, numTypes[:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, byte(0), numTypes[0])
|
|
|
|
var reasonLen [4]byte
|
|
_, err = io.ReadFull(conn, reasonLen[:])
|
|
require.NoError(t, err)
|
|
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
|
_, err = io.ReadFull(conn, reason)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, string(reason), RejectCodeUnsupportedOS)
|
|
}
|