package server import ( "encoding/binary" "encoding/hex" "image" "io" "net" "net/netip" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // testCapturer returns a 100x100 image for test sessions. type testCapturer struct{} func (t *testCapturer) Width() int { return 100 } func (t *testCapturer) Height() int { return 100 } func (t *testCapturer) Capture() (*image.RGBA, error) { return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil } func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) { t.Helper() srv := New(&testCapturer{}, &StubInputInjector{}, "") srv.SetDisableAuth(disableAuth) if jwtConfig != nil { srv.SetJWTConfig(jwtConfig) } addr := netip.MustParseAddrPort("127.0.0.1:0") network := netip.MustParsePrefix("127.0.0.0/8") require.NoError(t, srv.Start(t.Context(), addr, network)) // Override local address so source validation doesn't reject 127.0.0.1 as "own IP". srv.localAddr = netip.MustParseAddr("10.99.99.1") t.Cleanup(func() { _ = srv.Stop() }) return srv.listener.Addr(), srv } func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) { addr, _ := startTestServer(t, false, nil) conn, err := net.Dial("tcp", addr.String()) require.NoError(t, err) defer conn.Close() // Send session header: attach mode, no username, no JWT. header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0 header[0] = ModeAttach _, err = conn.Write(header) require.NoError(t, err) // Server should send RFB version then security failure. var version [12]byte _, err = io.ReadFull(conn, version[:]) require.NoError(t, err) assert.Equal(t, "RFB 003.008\n", string(version[:])) // Write client version to proceed through handshake. _, err = conn.Write(version[:]) require.NoError(t, err) // Read security types: 0 means failure, followed by reason. var numTypes [1]byte _, err = io.ReadFull(conn, numTypes[:]) require.NoError(t, err) assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)") var reasonLen [4]byte _, err = io.ReadFull(conn, reasonLen[:]) require.NoError(t, err) reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:])) _, err = io.ReadFull(conn, reason) require.NoError(t, err) assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config") } func TestAuthDisabled_AllowsConnection(t *testing.T) { addr, _ := startTestServer(t, true, nil) conn, err := net.Dial("tcp", addr.String()) require.NoError(t, err) defer conn.Close() // Send session header: attach mode, no username, no JWT. header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0 header[0] = ModeAttach _, err = conn.Write(header) require.NoError(t, err) // Server should send RFB version. var version [12]byte _, err = io.ReadFull(conn, version[:]) require.NoError(t, err) assert.Equal(t, "RFB 003.008\n", string(version[:])) // Write client version. _, err = conn.Write(version[:]) require.NoError(t, err) // Should get security types (not 0 = failure). var numTypes [1]byte _, err = io.ReadFull(conn, numTypes[:]) require.NoError(t, err) assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)") } // TestAuthEnabled_InvalidJWT_RejectedBeforeRFB confirms the VNC server itself // (not just the JWT library) wires authentication into handleConnection. A // well-formed JWT-shaped token must hit the server's validation path and be // rejected with an AUTH_JWT_* reason, never reaching the RFB handshake. func TestAuthEnabled_InvalidJWT_RejectedBeforeRFB(t *testing.T) { addr, _ := startTestServer(t, false, &JWTConfig{ Issuer: "https://example.invalid", KeysLocation: "https://example.invalid/.well-known/jwks.json", Audiences: []string{"test"}, }) // Three-segment "JWT" with bogus base64. The server's authenticateJWT path // must catch this regardless of the IdP being unreachable. bogusJWT := "abc.def.ghi" header := make([]byte, 3+2+len(bogusJWT)+4+4) header[0] = ModeAttach binary.BigEndian.PutUint16(header[1:3], 0) // username len binary.BigEndian.PutUint16(header[3:5], uint16(len(bogusJWT))) copy(header[5:5+len(bogusJWT)], bogusJWT) conn, err := net.Dial("tcp", addr.String()) require.NoError(t, err) defer conn.Close() require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second))) _, err = conn.Write(header) require.NoError(t, err) var version [12]byte _, err = io.ReadFull(conn, version[:]) require.NoError(t, err) _, err = conn.Write(version[:]) require.NoError(t, err) var numTypes [1]byte _, err = io.ReadFull(conn, numTypes[:]) require.NoError(t, err) require.Equal(t, byte(0), numTypes[0], "must fail security negotiation") var reasonLen [4]byte _, err = io.ReadFull(conn, reasonLen[:]) require.NoError(t, err) reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:])) _, err = io.ReadFull(conn, reason) require.NoError(t, err) // The reason must carry one of the server's AUTH_JWT_* codes, proving // the rejection came from authenticateJWT in handleConnection. r := string(reason) hasJWTReject := false for _, code := range []string{RejectCodeJWTInvalid, RejectCodeJWTExpired, RejectCodeAuthForbidden} { if strings.Contains(r, code) { hasJWTReject = true break } } assert.True(t, hasJWTReject, "reason %q must include an AUTH_JWT_* code", r) } // TestAuth_NoUnauthBytesPastHeader proves the server does not send any RFB // content to a connection that fails source validation. Specifically, the // server must close immediately and the client must see EOF before any RFB // version greeting is written. func TestAuth_NoUnauthBytesPastHeader(t *testing.T) { srv := New(&testCapturer{}, &StubInputInjector{}, "") srv.SetDisableAuth(true) addr := netip.MustParseAddrPort("127.0.0.1:0") // Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so // the loopback short-circuit in isAllowedSource doesn't apply. require.NoError(t, srv.Start(t.Context(), addr, netip.MustParsePrefix("10.99.0.0/16"))) srv.localAddr = netip.MustParseAddr("10.99.99.1") t.Cleanup(func() { _ = srv.Stop() }) conn, err := net.Dial("tcp", srv.listener.Addr().String()) require.NoError(t, err) defer conn.Close() require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second))) // Reading even one byte must EOF: the source IP (127.0.0.1) is outside // the configured overlay, so handleConnection closes before writing. var b [1]byte _, err = io.ReadFull(conn, b[:]) require.Error(t, err, "non-overlay client must see EOF, not an RFB greeting") } func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) { // Auth enabled with a (bogus) JWT config: connections without JWT should be rejected. addr, _ := startTestServer(t, false, &JWTConfig{ Issuer: "https://example.com", KeysLocation: "https://example.com/.well-known/jwks.json", Audiences: []string{"test"}, }) conn, err := net.Dial("tcp", addr.String()) require.NoError(t, err) defer conn.Close() // Send session header with empty JWT. header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0 header[0] = ModeAttach _, err = conn.Write(header) require.NoError(t, err) var version [12]byte _, err = io.ReadFull(conn, version[:]) require.NoError(t, err) _, err = conn.Write(version[:]) require.NoError(t, err) var numTypes [1]byte _, err = io.ReadFull(conn, numTypes[:]) require.NoError(t, err) assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types") } func TestIsAllowedSource(t *testing.T) { tests := []struct { name string localAddr netip.Addr network netip.Prefix remote net.Addr want bool }{ { name: "non-tcp address rejected", localAddr: netip.MustParseAddr("10.99.99.1"), network: netip.MustParsePrefix("10.99.0.0/16"), remote: &net.UDPAddr{IP: net.ParseIP("10.99.99.2"), Port: 1234}, want: false, }, { name: "own IP rejected", localAddr: netip.MustParseAddr("10.99.99.1"), network: netip.MustParsePrefix("10.99.0.0/16"), remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.1"), Port: 5900}, want: false, }, { name: "non-overlay IP rejected", localAddr: netip.MustParseAddr("10.99.99.1"), network: netip.MustParsePrefix("10.99.0.0/16"), remote: &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 5900}, want: false, }, { name: "overlay IP allowed", localAddr: netip.MustParseAddr("10.99.99.1"), network: netip.MustParsePrefix("10.99.0.0/16"), remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900}, want: true, }, { name: "v4-mapped v6 in overlay allowed (unmapped)", localAddr: netip.MustParseAddr("10.99.99.1"), network: netip.MustParsePrefix("10.99.0.0/16"), remote: &net.TCPAddr{IP: net.ParseIP("::ffff:10.99.99.2"), Port: 5900}, want: true, }, { name: "loopback allowed only when local is loopback", localAddr: netip.MustParseAddr("127.0.0.1"), network: netip.MustParsePrefix("127.0.0.0/8"), remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.5"), Port: 5900}, want: true, }, { name: "invalid network rejected (fail-closed)", localAddr: netip.MustParseAddr("10.99.99.1"), network: netip.Prefix{}, remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900}, want: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { srv := New(&testCapturer{}, &StubInputInjector{}, "") srv.localAddr = tc.localAddr srv.network = tc.network assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote)) }) } } func TestStart_InvalidNetworkRejected(t *testing.T) { srv := New(&testCapturer{}, &StubInputInjector{}, "") addr := netip.MustParseAddrPort("127.0.0.1:0") err := srv.Start(t.Context(), addr, netip.Prefix{}) require.Error(t, err, "Start must refuse an invalid overlay prefix") assert.Contains(t, err.Error(), "invalid overlay network prefix") } func TestAgentToken_MismatchClosesConnection(t *testing.T) { srv := New(&testCapturer{}, &StubInputInjector{}, "") srv.SetDisableAuth(true) srv.SetAgentToken("deadbeefcafebabe") addr := netip.MustParseAddrPort("127.0.0.1:0") network := netip.MustParsePrefix("127.0.0.0/8") require.NoError(t, srv.Start(t.Context(), addr, network)) srv.localAddr = netip.MustParseAddr("10.99.99.1") t.Cleanup(func() { _ = srv.Stop() }) conn, err := net.Dial("tcp", srv.listener.Addr().String()) require.NoError(t, err) defer conn.Close() require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second))) // Send a wrong token of the right length (8 bytes hex-decoded). if _, err := conn.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}); err != nil { // Server may already have closed; either way the read below must EOF. _ = err } // Server must close without sending the RFB greeting. var version [12]byte _, err = io.ReadFull(conn, version[:]) require.Error(t, err, "server must close the connection on bad agent token") } func TestAgentToken_MatchAllowsHandshake(t *testing.T) { srv := New(&testCapturer{}, &StubInputInjector{}, "") srv.SetDisableAuth(true) const tokenHex = "deadbeefcafebabe" srv.SetAgentToken(tokenHex) token, err := hex.DecodeString(tokenHex) require.NoError(t, err) addr := netip.MustParseAddrPort("127.0.0.1:0") network := netip.MustParsePrefix("127.0.0.0/8") require.NoError(t, srv.Start(t.Context(), addr, network)) srv.localAddr = netip.MustParseAddr("10.99.99.1") t.Cleanup(func() { _ = srv.Stop() }) conn, err := net.Dial("tcp", srv.listener.Addr().String()) require.NoError(t, err) defer conn.Close() require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second))) _, err = conn.Write(token) require.NoError(t, err) // Send session header so handleConnection can proceed past readConnectionHeader. header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0 header[0] = ModeAttach _, err = conn.Write(header) require.NoError(t, err) // With a matching token the server proceeds to the RFB greeting. var version [12]byte _, err = io.ReadFull(conn, version[:]) require.NoError(t, err, "server must keep the connection open after a valid agent token") assert.Equal(t, "RFB 003.008\n", string(version[:])) } func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) { // Default platformSessionManager() on non-Linux returns nil, so ModeSession // must be rejected with the UNSUPPORTED reason rather than crashing. srv := New(&testCapturer{}, &StubInputInjector{}, "") srv.SetDisableAuth(true) addr := netip.MustParseAddrPort("127.0.0.1:0") network := netip.MustParsePrefix("127.0.0.0/8") require.NoError(t, srv.Start(t.Context(), addr, network)) srv.localAddr = netip.MustParseAddr("10.99.99.1") // Force vmgr to nil regardless of platform so the test is deterministic. srv.vmgr = nil t.Cleanup(func() { _ = srv.Stop() }) conn, err := net.Dial("tcp", srv.listener.Addr().String()) require.NoError(t, err) defer conn.Close() require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second))) // ModeSession with no username/JWT, so we exit on the vmgr==nil branch // before username validation runs. header := []byte{ModeSession, 0, 0, 0, 0} _, err = conn.Write(header) require.NoError(t, err) var version [12]byte _, err = io.ReadFull(conn, version[:]) require.NoError(t, err) _, err = conn.Write(version[:]) require.NoError(t, err) var numTypes [1]byte _, err = io.ReadFull(conn, numTypes[:]) require.NoError(t, err) assert.Equal(t, byte(0), numTypes[0]) var reasonLen [4]byte _, err = io.ReadFull(conn, reasonLen[:]) require.NoError(t, err) reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:])) _, err = io.ReadFull(conn, reason) require.NoError(t, err) assert.Contains(t, string(reason), RejectCodeUnsupportedOS) }