Replace VNC JWT auth with a Noise_IK handshake bound to ACL-pushed pubkeys

This commit is contained in:
Viktor Liu
2026-05-21 16:49:47 +02:00
parent 2f4ddf0796
commit 3d3055dc7f
36 changed files with 2014 additions and 1118 deletions

View File

@@ -39,12 +39,22 @@ var vncAgentCmd = &cobra.Command{
if token == "" {
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
}
// Drop the token from our process environment so any child the
// agent spawns does not inherit it, and casual debugging tools
// that dump /proc/<pid>/environ (or the Windows equivalent) on a
// running agent don't surface the loopback shared secret.
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
}
capturer, injector, err := newAgentResources()
if err != nil {
return err
}
srv := vncserver.New(capturer, injector)
// The per-user agent listens only on loopback and is gated by an
// agent token shared with the daemon, so no X25519 identity key
// is needed; auth is disabled at the RFB layer.
srv := vncserver.New(capturer, injector, nil)
srv.SetDisableAuth(true)
srv.SetAgentToken(token)

View File

@@ -1064,7 +1064,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
if err := e.updateVNC(); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}

View File

@@ -66,8 +66,7 @@ func (e *Engine) cleanupVNCPortRedirection() error {
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
// sshConf provides the JWT identity provider config (shared with SSH).
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
func (e *Engine) updateVNC() error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
@@ -81,15 +80,13 @@ func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
}
if e.vncSrv != nil {
// Update JWT config on existing server in case management sent new config.
e.updateVNCServerJWT(sshConf)
return nil
}
return e.startVNCServer(sshConf)
return e.startVNCServer()
}
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
func (e *Engine) startVNCServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
@@ -102,7 +99,7 @@ func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
netbirdIP := e.wgInterface.Address().IP
srv := vncserver.New(capturer, injector)
srv := vncserver.New(capturer, injector, e.config.WgPrivateKey[:])
if e.clientMetrics != nil {
srv.SetSessionRecorder(func(t vncserver.SessionTick) {
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
@@ -122,20 +119,6 @@ func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
srv.SetServiceMode(true)
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
srv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
srv.SetNetstackNet(netstackNet)
}
@@ -165,35 +148,6 @@ func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
return nil
}
// updateVNCServerJWT configures the JWT validation for the VNC server using
// the same JWT config as SSH (same identity provider).
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
protoJWT := sshConf.GetJwtConfig()
if protoJWT == nil {
return
}
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
@@ -221,10 +175,28 @@ func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
machineUsers[osUser] = indexes.GetIndexes()
}
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
for _, e := range vncAuth.GetSessionPubKeys() {
pub := e.GetPubKey()
if len(pub) != 32 {
log.Warnf("VNC session pubkey wrong length %d", len(pub))
continue
}
hash := e.GetUserIdHash()
if len(hash) != 16 {
log.Warnf("VNC session user id hash wrong length %d", len(hash))
continue
}
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
PubKey: pub,
UserIDHash: sshuserhash.UserIDHash(hash),
})
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
UserIDClaim: vncAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
SessionPubKeys: sessionPubKeys,
})
}

View File

@@ -8,7 +8,7 @@ import (
type vncServer interface{}
func (e *Engine) updateVNC(_ *mgmProto.SSHConfig) error { return nil }
func (e *Engine) updateVNC() error { return nil }
func (e *Engine) updateVNCServerAuth(_ *mgmProto.VNCAuth) {
// no-op on platforms without a VNC server

View File

@@ -2107,7 +2107,9 @@ type VNCSessionInfo struct {
RemoteAddress string `protobuf:"bytes,1,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
Mode string `protobuf:"bytes,2,opt,name=mode,proto3" json:"mode,omitempty"`
Username string `protobuf:"bytes,3,opt,name=username,proto3" json:"username,omitempty"`
JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
// userID is the Noise-verified session identity (hashed user ID from
// the ACL session-key entry), empty when auth is disabled.
UserID string `protobuf:"bytes,4,opt,name=userID,proto3" json:"userID,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -2163,9 +2165,9 @@ func (x *VNCSessionInfo) GetUsername() string {
return ""
}
func (x *VNCSessionInfo) GetJwtUsername() string {
func (x *VNCSessionInfo) GetUserID() string {
if x != nil {
return x.JwtUsername
return x.UserID
}
return ""
}
@@ -6582,12 +6584,12 @@ const file_daemon_proto_rawDesc = "" +
"\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" +
"\x0eSSHServerState\x12\x18\n" +
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\x88\x01\n" +
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"~\n" +
"\x0eVNCSessionInfo\x12$\n" +
"\rremoteAddress\x18\x01 \x01(\tR\rremoteAddress\x12\x12\n" +
"\x04mode\x18\x02 \x01(\tR\x04mode\x12\x1a\n" +
"\busername\x18\x03 \x01(\tR\busername\x12 \n" +
"\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" +
"\busername\x18\x03 \x01(\tR\busername\x12\x16\n" +
"\x06userID\x18\x04 \x01(\tR\x06userID\"^\n" +
"\x0eVNCServerState\x12\x18\n" +
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
"\bsessions\x18\x02 \x03(\v2\x16.daemon.VNCSessionInfoR\bsessions\"\xef\x04\n" +

View File

@@ -403,7 +403,9 @@ message VNCSessionInfo {
string remoteAddress = 1;
string mode = 2;
string username = 3;
string jwtUsername = 4;
// userID is the Noise-verified session identity (hashed user ID from
// the ACL session-key entry), empty when auth is disabled.
string userID = 4;
}
// VNCServerState contains the latest state of the VNC server

View File

@@ -1199,7 +1199,7 @@ func (s *Server) getVNCServerState() *proto.VNCServerState {
RemoteAddress: sess.RemoteAddress,
Mode: sess.Mode,
Username: sess.Username,
JwtUsername: sess.JWTUsername,
UserID: sess.UserID,
})
}
return &proto.VNCServerState{

View File

@@ -15,13 +15,16 @@ const (
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
sessionPubKeyLen = 32
)
var (
ErrEmptyUserID = errors.New("JWT user ID is empty")
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
ErrEmptyUserID = errors.New("JWT user ID is empty")
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
)
// Authorizer handles SSH fine-grained access control authorization
@@ -35,6 +38,12 @@ type Authorizer struct {
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// sessionPubKeys maps an X25519 static public key (as map-safe
// array) to the hashed user identity that key authenticates as.
// Populated from management's temporary-access flow; used by VNC to
// authenticate via the Noise_IK handshake.
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
// mu protects the list of users
mu sync.RWMutex
}
@@ -50,13 +59,25 @@ type Config struct {
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to login as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
// user identities. Populated for VNC; ignored on the SSH side.
SessionPubKeys []SessionPubKey
}
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
// static public key plus the hashed user identity it authenticates as.
type SessionPubKey struct {
PubKey []byte
UserIDHash sshuserhash.UserIDHash
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
}
return a
@@ -72,6 +93,7 @@ func (a *Authorizer) Update(config *Config) {
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
log.Info("SSH authorization cleared")
return
}
@@ -94,8 +116,19 @@ func (a *Authorizer) Update(config *Config) {
}
a.machineUsers = machineUsers
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
for _, e := range config.SessionPubKeys {
if len(e.PubKey) != sessionPubKeyLen {
continue
}
var key [sessionPubKeyLen]byte
copy(key[:], e.PubKey)
sessionPubKeys[key] = e.UserIDHash
}
a.sessionPubKeys = sessionPubKeys
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
}
// Authorize validates if a user is authorized to login as the specified OS user.
@@ -155,6 +188,38 @@ func (a *Authorizer) GetUserIDClaim() string {
return a.userIDClaim
}
// LookupSessionKey resolves a Noise-verified static public key to the
// hashed user identity registered with it. Fails closed when the key is
// unknown.
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
var zero sshuserhash.UserIDHash
if len(pubKey) != sessionPubKeyLen {
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
}
var key [sessionPubKeyLen]byte
copy(key[:], pubKey)
a.mu.RLock()
hash, ok := a.sessionPubKeys[key]
a.mu.RUnlock()
if !ok {
return zero, ErrSessionKeyNotKnown
}
return hash, nil
}
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
// key. Mirrors Authorize but skips the JWT-hash step since the key has
// already been verified and the user identity hash is in hand.
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
a.mu.RLock()
defer a.mu.RUnlock()
userIndex, found := a.findUserIndex(userIDHash)
if !found {
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
}
return a.checkMachineUserMapping("session", osUsername, userIndex)
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {

View File

@@ -1,6 +1,7 @@
package auth
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -610,3 +611,61 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
pub := bytesRepeat(0x11, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{
AuthorizedUsers: []sshauth.UserIDHash{userHash},
MachineUsers: map[string][]uint32{Wildcard: {0}},
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
})
got, err := a.LookupSessionKey(pub)
require.NoError(t, err)
assert.Equal(t, userHash, got)
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
}
}
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
a := NewAuthorizer()
a.Update(&Config{})
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
}
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
a := NewAuthorizer()
_, err := a.LookupSessionKey([]byte("short"))
require.Error(t, err)
}
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
pub := bytesRepeat(0x33, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
if _, err := a.LookupSessionKey(pub); err != nil {
t.Fatalf("setup lookup: %v", err)
}
a.Update(&Config{})
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
}
}
func bytesRepeat(b byte, n int) []byte {
out := make([]byte, n)
for i := range out {
out[i] = b
}
return out
}

View File

@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
}
}
// generateS4UUserToken creates a Windows token using S4U authentication.
// This is the same approach OpenSSH for Windows uses for public key authentication.
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)

View File

@@ -551,7 +551,27 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
maxTokenAge = DefaultJWTMaxTokenAge
}
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
@@ -582,7 +602,27 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
}
func extractUserID(token *gojwt.Token) string {
return jwt.UserIDFromToken(token)
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {

View File

@@ -135,7 +135,7 @@ type VNCSessionOutput struct {
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Mode string `json:"mode" yaml:"mode"`
Username string `json:"username,omitempty" yaml:"username,omitempty"`
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
}
type VNCServerStateOutput struct {
@@ -296,7 +296,7 @@ func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
RemoteAddress: sess.GetRemoteAddress(),
Mode: sess.GetMode(),
Username: sess.GetUsername(),
JWTUsername: sess.GetJwtUsername(),
UserID: sess.GetUserID(),
})
}
return VNCServerStateOutput{
@@ -583,9 +583,9 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
if showSSHSessions && vncSessionCount > 0 {
for _, sess := range o.VNCServerState.Sessions {
var line string
if sess.JWTUsername != "" {
if sess.UserID != "" {
line = fmt.Sprintf("[%s@%s -> %s] mode=%s",
sess.JWTUsername, sess.RemoteAddress, sess.Username, sess.Mode)
sess.UserID, sess.RemoteAddress, sess.Username, sess.Mode)
} else {
line = fmt.Sprintf("[%s] mode=%s user=%s",
sess.RemoteAddress, sess.Mode, sess.Username)

View File

@@ -0,0 +1,431 @@
//go:build !js && !ios && !android
package server
import (
"encoding/binary"
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/curve25519"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
// noiseTestServer starts a VNC server with a freshly generated identity
// key and returns the listener address, the server, and the server's
// static public key for client-side handshake setup.
func noiseTestServer(t *testing.T) (net.Addr, *Server, []byte) {
t.Helper()
kp, err := noise.DH25519.GenerateKeypair(nil)
require.NoError(t, err)
srv := New(&testCapturer{}, &StubInputInjector{}, kp.Private)
srv.SetDisableAuth(false)
addr := netip.MustParseAddrPort("127.0.0.1:0")
network := netip.MustParsePrefix("127.0.0.0/8")
require.NoError(t, srv.Start(t.Context(), addr, network))
srv.localAddr = netip.MustParseAddr("10.99.99.1")
t.Cleanup(func() { _ = srv.Stop() })
return srv.listener.Addr(), srv, kp.Public
}
// registerSessionKey enrolls a fresh X25519 keypair under the given user
// ID into the server's authorizer with the requested OS-user wildcard
// mapping. Returns the keypair so the test can drive the handshake.
func registerSessionKey(t *testing.T, srv *Server, userID string) noise.DHKey {
t.Helper()
kp, err := noise.DH25519.GenerateKeypair(nil)
require.NoError(t, err)
userHash, err := sshuserhash.HashUserID(userID)
require.NoError(t, err)
srv.UpdateVNCAuth(&sshauth.Config{
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
MachineUsers: map[string][]uint32{sshauth.Wildcard: {0}},
SessionPubKeys: []sshauth.SessionPubKey{
{PubKey: kp.Public, UserIDHash: userHash},
},
})
return kp
}
// writeHeaderPrefix writes the mode + zero-length-username prefix that
// precedes the optional Noise handshake in the NetBird VNC header.
func writeHeaderPrefix(t *testing.T, conn net.Conn, mode byte) {
t.Helper()
prefix := []byte{mode, 0, 0}
_, err := conn.Write(prefix)
require.NoError(t, err)
}
// writeHeaderTail writes the sessionID/width/height fields that follow
// either the Noise msg2 (auth path) or the prefix alone (no-auth path).
func writeHeaderTail(t *testing.T, conn net.Conn) {
t.Helper()
tail := make([]byte, 8)
_, err := conn.Write(tail)
require.NoError(t, err)
}
// performInitiator drives the initiator side of Noise_IK against the
// server's identity public key, returns the resulting state. The Noise
// msg2 produced by the server is read and consumed.
func performInitiator(t *testing.T, conn net.Conn, clientKey noise.DHKey, serverPub []byte) {
t.Helper()
state, err := noise.NewHandshakeState(noise.Config{
CipherSuite: vncNoiseSuite,
Pattern: noise.HandshakeIK,
Initiator: true,
StaticKeypair: clientKey,
PeerStatic: serverPub,
})
require.NoError(t, err)
msg1, _, _, err := state.WriteMessage(nil, nil)
require.NoError(t, err)
require.Equal(t, noiseInitiatorMsgLen, len(msg1))
_, err = conn.Write(append([]byte("NBV3"), msg1...))
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
msg2 := make([]byte, noiseResponderMsgLen)
_, err = io.ReadFull(conn, msg2)
require.NoError(t, err)
_, _, _, err = state.ReadMessage(nil, msg2)
require.NoError(t, err, "server responder message must decrypt with the correct peer static")
}
// readRFBFailure consumes the RFB version exchange and returns the
// security-failure reason string. Fails the test if the server did not
// send a failure (i.e. produced a non-zero security-types list).
func readRFBFailure(t *testing.T, conn net.Conn) string {
t.Helper()
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
var ver [12]byte
_, err := io.ReadFull(conn, ver[:])
require.NoError(t, err)
require.Equal(t, "RFB 003.008\n", string(ver[:]))
_, err = conn.Write(ver[:])
require.NoError(t, err)
var n [1]byte
_, err = io.ReadFull(conn, n[:])
require.NoError(t, err)
require.Equal(t, byte(0), n[0], "expected security-failure (0 types)")
var rl [4]byte
_, err = io.ReadFull(conn, rl[:])
require.NoError(t, err)
reason := make([]byte, binary.BigEndian.Uint32(rl[:]))
_, err = io.ReadFull(conn, reason)
require.NoError(t, err)
return string(reason)
}
// readRFBGreetingNoFailure asserts the server proceeded past auth: it
// must offer at least one security type rather than a 0 failure.
func readRFBGreetingNoFailure(t *testing.T, conn net.Conn) {
t.Helper()
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
var ver [12]byte
_, err := io.ReadFull(conn, ver[:])
require.NoError(t, err)
require.Equal(t, "RFB 003.008\n", string(ver[:]))
_, err = conn.Write(ver[:])
require.NoError(t, err)
var n [1]byte
_, err = io.ReadFull(conn, n[:])
require.NoError(t, err)
require.NotEqual(t, byte(0), n[0], "server must offer security types after a valid handshake")
}
// TestNoise_RegisteredKey_AccessGranted exercises the happy path: a
// session key enrolled in the authorizer completes a Noise_IK handshake
// and the server proceeds to the RFB greeting.
func TestNoise_RegisteredKey_AccessGranted(t *testing.T) {
addr, srv, serverPub := noiseTestServer(t)
clientKey := registerSessionKey(t, srv, "alice@example")
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
writeHeaderPrefix(t, conn, ModeAttach)
performInitiator(t, conn, clientKey, serverPub)
writeHeaderTail(t, conn)
readRFBGreetingNoFailure(t, conn)
}
// TestNoise_UnregisteredClientStatic_Rejected proves the authorizer is
// consulted: a syntactically-valid handshake from a key the server has
// never been told about must be rejected fail-closed.
func TestNoise_UnregisteredClientStatic_Rejected(t *testing.T) {
addr, _, serverPub := noiseTestServer(t)
// Auth is enabled but the authorizer was not updated, so the lookup
// path returns ErrSessionKeyNotKnown.
attackerKey, err := noise.DH25519.GenerateKeypair(nil)
require.NoError(t, err)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
writeHeaderPrefix(t, conn, ModeAttach)
performInitiator(t, conn, attackerKey, serverPub)
writeHeaderTail(t, conn)
reason := readRFBFailure(t, conn)
assert.Contains(t, reason, RejectCodeAuthForbidden)
assert.Contains(t, reason, "session pubkey not registered")
}
// TestNoise_WrongServerStatic_HandshakeFails proves the server's
// identity is bound into the handshake: an initiator using the wrong
// peer static encrypts msg1 under keys the real server can't derive, so
// the server fails the handshake and closes without RFB output.
func TestNoise_WrongServerStatic_HandshakeFails(t *testing.T) {
addr, srv, _ := noiseTestServer(t)
clientKey := registerSessionKey(t, srv, "alice@example")
bogusServerKey, err := noise.DH25519.GenerateKeypair(nil)
require.NoError(t, err)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
writeHeaderPrefix(t, conn, ModeAttach)
state, err := noise.NewHandshakeState(noise.Config{
CipherSuite: vncNoiseSuite,
Pattern: noise.HandshakeIK,
Initiator: true,
StaticKeypair: clientKey,
PeerStatic: bogusServerKey.Public,
})
require.NoError(t, err)
msg1, _, _, err := state.WriteMessage(nil, nil)
require.NoError(t, err)
_, err = conn.Write(append([]byte("NBV3"), msg1...))
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
var b [1]byte
_, err = io.ReadFull(conn, b[:])
require.Error(t, err, "server must close without RFB greeting when msg1 is sealed for a different server identity")
}
// TestNoise_MalformedMsg1_ClosesConnection covers the case where the
// magic prefix is correct but the following 96 bytes are random: the
// noise library fails ReadMessage and the server closes silently.
func TestNoise_MalformedMsg1_ClosesConnection(t *testing.T) {
addr, _, _ := noiseTestServer(t)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
writeHeaderPrefix(t, conn, ModeAttach)
junk := make([]byte, noiseInitiatorMsgLen)
for i := range junk {
junk[i] = byte(i)
}
_, err = conn.Write(append([]byte("NBV3"), junk...))
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
var b [1]byte
_, err = io.ReadFull(conn, b[:])
require.Error(t, err, "garbage msg1 must terminate the connection before any RFB output")
}
// TestNoise_TruncatedMsg1_ClosesConnection sends fewer than the 96
// bytes a Noise_IK msg1 must contain. The server's io.ReadFull short-
// reads and closes; no RFB greeting must leak.
func TestNoise_TruncatedMsg1_ClosesConnection(t *testing.T) {
addr, _, _ := noiseTestServer(t)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
writeHeaderPrefix(t, conn, ModeAttach)
_, err = conn.Write([]byte("NBV3"))
require.NoError(t, err)
_, err = conn.Write(make([]byte, 8))
require.NoError(t, err)
require.NoError(t, conn.Close())
// Re-dial just to confirm the listener is alive (the previous
// connection terminated server-side without affecting the listener).
probe, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
require.NoError(t, probe.Close())
}
// TestNoise_AuthEnabled_NoHandshake_Rejected proves that with auth on,
// a connection that skips the Noise prefix (older client / VNC client)
// is rejected with AUTH_FORBIDDEN: identity proof missing.
func TestNoise_AuthEnabled_NoHandshake_Rejected(t *testing.T) {
addr, _, _ := noiseTestServer(t)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
writeHeaderPrefix(t, conn, ModeAttach)
writeHeaderTail(t, conn)
reason := readRFBFailure(t, conn)
assert.Contains(t, reason, RejectCodeAuthForbidden)
assert.Contains(t, reason, "identity proof missing")
}
// TestNoise_RevokedKey_RejectedAfterAuthUpdate verifies the authorizer
// honors revocations: a key that worked before a UpdateVNCAuth call
// must stop working as soon as the new config omits it.
func TestNoise_RevokedKey_RejectedAfterAuthUpdate(t *testing.T) {
addr, srv, serverPub := noiseTestServer(t)
clientKey := registerSessionKey(t, srv, "alice@example")
// First connection succeeds.
conn1, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn1.Close()
writeHeaderPrefix(t, conn1, ModeAttach)
performInitiator(t, conn1, clientKey, serverPub)
writeHeaderTail(t, conn1)
readRFBGreetingNoFailure(t, conn1)
// Revoke by pushing a fresh config that drops the pubkey entry.
srv.UpdateVNCAuth(&sshauth.Config{})
// Same client, same Noise key, should now be denied.
conn2, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn2.Close()
writeHeaderPrefix(t, conn2, ModeAttach)
performInitiator(t, conn2, clientKey, serverPub)
writeHeaderTail(t, conn2)
reason := readRFBFailure(t, conn2)
assert.Contains(t, reason, RejectCodeAuthForbidden)
assert.Contains(t, reason, "session pubkey not registered")
}
// TestNoise_NoIdentityKey_FailsClosed ensures a server constructed
// without a static private key still rejects authenticated connections
// fail-closed; it must not silently accept the client.
func TestNoise_NoIdentityKey_FailsClosed(t *testing.T) {
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
srv.SetDisableAuth(false)
addr := netip.MustParseAddrPort("127.0.0.1:0")
network := netip.MustParsePrefix("127.0.0.0/8")
require.NoError(t, srv.Start(t.Context(), addr, network))
srv.localAddr = netip.MustParseAddr("10.99.99.1")
t.Cleanup(func() { _ = srv.Stop() })
clientKey, err := noise.DH25519.GenerateKeypair(nil)
require.NoError(t, err)
fakeServerKey, err := noise.DH25519.GenerateKeypair(nil)
require.NoError(t, err)
conn, err := net.Dial("tcp", srv.listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
writeHeaderPrefix(t, conn, ModeAttach)
state, err := noise.NewHandshakeState(noise.Config{
CipherSuite: vncNoiseSuite,
Pattern: noise.HandshakeIK,
Initiator: true,
StaticKeypair: clientKey,
PeerStatic: fakeServerKey.Public,
})
require.NoError(t, err)
msg1, _, _, err := state.WriteMessage(nil, nil)
require.NoError(t, err)
_, err = conn.Write(append([]byte("NBV3"), msg1...))
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
var b [1]byte
_, err = io.ReadFull(conn, b[:])
require.Error(t, err, "server without identity key must not write the RFB greeting")
}
// TestNoise_DerivedIdentityPublicMatchesPrivate sanity-checks the
// derivation done in New(): the identityPublic must be Curve25519.
// Basepoint multiplied with identityKey.
func TestNoise_DerivedIdentityPublicMatchesPrivate(t *testing.T) {
priv := make([]byte, 32)
for i := range priv {
priv[i] = byte(i + 1)
}
srv := New(&testCapturer{}, &StubInputInjector{}, priv)
expected, err := curve25519.X25519(priv, curve25519.Basepoint)
require.NoError(t, err)
assert.Equal(t, expected, srv.identityPublic)
}
// TestNoise_SessionMode_OSUserCheckRunsAfterHandshake verifies that a
// successful Noise handshake doesn't bypass OS-user authorization: an
// authenticated key whose user index isn't mapped to the requested OS
// user must be rejected.
func TestNoise_SessionMode_OSUserCheckRunsAfterHandshake(t *testing.T) {
addr, srv, serverPub := noiseTestServer(t)
clientKey, err := noise.DH25519.GenerateKeypair(nil)
require.NoError(t, err)
userHash, err := sshuserhash.HashUserID("alice@example")
require.NoError(t, err)
// Map Alice only to "alice" OS user, not the wildcard.
srv.UpdateVNCAuth(&sshauth.Config{
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
MachineUsers: map[string][]uint32{"alice": {0}},
SessionPubKeys: []sshauth.SessionPubKey{
{PubKey: clientKey.Public, UserIDHash: userHash},
},
})
// Request session for "bob" — Noise succeeds, OS-user check denies.
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
bob := []byte("bob")
prefix := []byte{ModeSession, 0, byte(len(bob))}
prefix = append(prefix, bob...)
_, err = conn.Write(prefix)
require.NoError(t, err)
performInitiator(t, conn, clientKey, serverPub)
writeHeaderTail(t, conn)
reason := readRFBFailure(t, conn)
assert.Contains(t, reason, RejectCodeAuthForbidden)
assert.Contains(t, reason, "authorize OS user")
}

View File

@@ -3,6 +3,8 @@
package server
import (
"bufio"
"bytes"
"context"
"crypto/subtle"
"encoding/binary"
@@ -13,16 +15,15 @@ import (
"io"
"net"
"net/netip"
"strings"
"sync"
"time"
gojwt "github.com/golang-jwt/jwt/v5"
"github.com/flynn/noise"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/tun/netstack"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
// Connection modes sent by the client in the session header.
@@ -35,11 +36,7 @@ const (
// stable so clients can branch on them without parsing free text.
// Format: "CODE: human message".
const (
RejectCodeJWTMissing = "AUTH_JWT_MISSING"
RejectCodeJWTExpired = "AUTH_JWT_EXPIRED"
RejectCodeJWTInvalid = "AUTH_JWT_INVALID"
RejectCodeAuthForbidden = "AUTH_FORBIDDEN"
RejectCodeAuthConfig = "AUTH_CONFIG"
RejectCodeSessionError = "SESSION_ERROR"
RejectCodeCapturerError = "CAPTURER_ERROR"
RejectCodeUnsupportedOS = "UNSUPPORTED"
@@ -56,6 +53,21 @@ const EnvVNCDisableDownscale = "NB_VNC_DISABLE_DOWNSCALE"
// enough to coalesce bursty multi-session requests. 16 ms ~= 60 fps.
const freshWindow = 16 * time.Millisecond
// maxConcurrentVNCConns caps in-flight VNC connections. Each accepted
// connection consumes a handler goroutine, a tracking entry, and (after
// handshake) capturer/encoder resources, so an unauthenticated peer that
// dials in a tight loop could otherwise grow memory without bound. The
// limit covers the entire accept→handshake→session window; a slot is
// released only when the handler returns.
const maxConcurrentVNCConns = 64
// maxFramebufferDim caps the screen dimensions accepted from a capturer.
// RFB serialises width/height as u16, and the encoder allocates per-frame
// buffers proportional to width*height*4. 8192 keeps width*height*4 well
// under 2^31 so int math doesn't overflow on 32-bit builds, and is large
// enough to cover real-world multi-monitor desktops.
const maxFramebufferDim = 8192
// ScreenCapturer grabs desktop frames for the VNC server.
type ScreenCapturer interface {
// Width returns the current screen width in pixels.
@@ -120,26 +132,22 @@ type InputInjector interface {
TypeText(text string)
}
// JWTConfig holds JWT validation configuration for VNC auth.
type JWTConfig struct {
Issuer string
KeysLocation string
MaxTokenAge int64
Audiences []string
}
// connectionHeader is sent by the client before the RFB handshake to specify
// the VNC session mode and authenticate.
type connectionHeader struct {
mode byte
username string
jwt string
// clientStatic is the client's static X25519 public key learned from
// the Noise handshake. Populated when identityVerified is true.
clientStatic []byte
// sessionID is the Windows session ID; 0 selects the console session.
sessionID uint32
// width and height request the virtual display geometry for session mode.
// Zero means use the default.
width uint16
height uint16
// identityVerified is true when the Noise_IK handshake completed.
identityVerified bool
}
// Server is the embedded VNC server that listens on the WireGuard interface.
@@ -170,13 +178,16 @@ type Server struct {
ctx context.Context
cancel context.CancelFunc
vmgr virtualSessionManager
jwtConfig *JWTConfig
jwtValidator *nbjwt.Validator
jwtExtractor *nbjwt.ClaimsExtractor
authorizer *sshauth.Authorizer
netstackNet *netstack.Net
authorizer *sshauth.Authorizer
netstackNet *netstack.Net
// agentToken holds the raw token bytes for agent-mode auth.
agentToken []byte
// identityKey is the daemon's static X25519 private key used in the
// Noise_IK handshake. Nil disables the handshake.
identityKey []byte
// identityPublic is the matching X25519 public key, derived once at
// construction to avoid recomputing per handshake.
identityPublic []byte
sessionsMu sync.Mutex
sessionSeq uint64
@@ -188,6 +199,17 @@ type Server struct {
// closeActiveSessions iterates this set so Stop() can interrupt
// handshaking peers, not just post-handshake sessions.
acceptedConns map[net.Conn]struct{}
// connAuth holds the verified Noise_IK identity tied to each accepted
// connection so a later UpdateVNCAuth call can revoke live sessions
// whose authorization no longer holds. Populated by registerConnAuth
// once authenticateSession succeeds; absent entries (e.g. disableAuth
// or pre-handshake conns) are skipped at revocation time.
connAuth map[net.Conn]connAuthInfo
// connSem caps concurrent accepted connections (handshake + session).
// Buffered with maxConcurrentVNCConns slots; accept loops try-acquire
// before spawning a handler and release on handler return.
connSem chan struct{}
// sessionRecorder, when non-nil, receives a SessionTick periodically
// during each VNC session and on session close. The engine wires
@@ -195,12 +217,24 @@ type Server struct {
sessionRecorder func(SessionTick)
}
// connAuthInfo captures the Noise_IK-verified identity bound to a live
// connection so policy updates can re-check it and close sessions whose
// authorization was revoked. clientStatic is empty when auth was disabled
// for this connection, which signals that revocation does not apply.
type connAuthInfo struct {
clientStatic []byte
mode byte
username string
}
// ActiveSessionInfo describes a currently connected VNC client.
type ActiveSessionInfo struct {
RemoteAddress string
Mode string
Username string
JWTUsername string
// UserID is the authenticated session identity (hashed user ID from
// the Noise_IK static-key registration), empty when auth is disabled.
UserID string
}
// vncSession provides capturer and injector for a virtual display session.
@@ -220,19 +254,31 @@ type virtualSessionManager interface {
StopAll()
}
// New creates a VNC server with the given screen capturer and input injector.
// Authentication uses a JWT supplied by the client in the connection
// header; the protocol-level VNC password scheme is not supported.
func New(capturer ScreenCapturer, injector InputInjector) *Server {
return &Server{
// New creates a VNC server. identityKey is the 32-byte X25519 private
// key used by the daemon in the Noise_IK handshake; nil disables auth.
// The protocol-level VNC password scheme is not supported.
func New(capturer ScreenCapturer, injector InputInjector, identityKey []byte) *Server {
s := &Server{
capturer: capturer,
injector: injector,
identityKey: identityKey,
authorizer: sshauth.NewAuthorizer(),
log: log.WithField("component", "vnc-server"),
sessions: make(map[uint64]ActiveSessionInfo),
sessionConns: make(map[uint64]net.Conn),
acceptedConns: make(map[net.Conn]struct{}),
connAuth: make(map[net.Conn]connAuthInfo),
connSem: make(chan struct{}, maxConcurrentVNCConns),
}
if len(identityKey) == 32 {
pub, err := curve25519.X25519(identityKey, curve25519.Basepoint)
if err == nil {
s.identityPublic = pub
} else {
s.log.Warnf("derive identity public key: %v", err)
}
}
return s
}
// ActiveSessions returns a snapshot of currently connected VNC clients.
@@ -292,9 +338,75 @@ func (s *Server) trackConn(c net.Conn) {
func (s *Server) untrackConn(c net.Conn) {
s.sessionsMu.Lock()
delete(s.acceptedConns, c)
delete(s.connAuth, c)
s.sessionsMu.Unlock()
}
// registerConnAuth records the verified Noise_IK identity for a live
// connection so UpdateVNCAuth can later revoke it if policy changes.
// No-op when auth is disabled (e.g. agent-mode loopback connections).
func (s *Server) registerConnAuth(c net.Conn, header *connectionHeader) {
if s.disableAuth || header == nil || len(header.clientStatic) != 32 {
return
}
s.sessionsMu.Lock()
s.connAuth[c] = connAuthInfo{
clientStatic: append([]byte(nil), header.clientStatic...),
mode: header.mode,
username: header.username,
}
s.sessionsMu.Unlock()
}
// tryAcquireConnSlot returns true when a connection slot was successfully
// reserved. Releases must pair with releaseConnSlot. Returns false when
// the cap is already saturated; callers must close the connection.
func (s *Server) tryAcquireConnSlot() bool {
select {
case s.connSem <- struct{}{}:
return true
default:
return false
}
}
func (s *Server) releaseConnSlot() {
select {
case <-s.connSem:
default:
}
}
// revokeUnauthorizedSessions closes every live connection whose Noise-
// verified identity no longer authenticates under the current authorizer
// configuration. Called by UpdateVNCAuth after the new policy is applied.
func (s *Server) revokeUnauthorizedSessions() {
if s.disableAuth {
return
}
s.sessionsMu.Lock()
victims := make([]net.Conn, 0)
for c, info := range s.connAuth {
if len(info.clientStatic) != 32 {
continue
}
hdr := &connectionHeader{
identityVerified: true,
clientStatic: info.clientStatic,
mode: info.mode,
username: info.username,
}
if _, err := s.authenticateSession(hdr); err != nil {
victims = append(victims, c)
s.log.Infof("revoking VNC session from %s: %v", c.RemoteAddr(), err)
}
}
s.sessionsMu.Unlock()
for _, c := range victims {
_ = c.Close()
}
}
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
func (s *Server) SetServiceMode(enabled bool) {
s.serviceMode = enabled
@@ -308,16 +420,6 @@ func (s *Server) SetSessionRecorder(recorder func(SessionTick)) {
s.sessionRecorder = recorder
}
// SetJWTConfig configures JWT authentication for VNC connections.
// Pass nil to disable JWT (public mode).
func (s *Server) SetJWTConfig(config *JWTConfig) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwtConfig = config
s.jwtValidator = nil
s.jwtExtractor = nil
}
// SetDisableAuth disables authentication entirely.
func (s *Server) SetDisableAuth(disable bool) {
s.disableAuth = disable
@@ -346,13 +448,14 @@ func (s *Server) SetNetstackNet(n *netstack.Net) {
s.netstackNet = n
}
// UpdateVNCAuth updates the fine-grained authorization configuration.
// UpdateVNCAuth updates the fine-grained authorization configuration and
// closes any live session whose identity no longer authenticates under
// the new policy. Revocation is event-driven: there is no periodic
// re-check, so a session stays open until either the next UpdateVNCAuth
// call or normal disconnect.
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
s.revokeUnauthorizedSessions()
}
// Start begins listening for VNC connections on the given address.
@@ -463,9 +566,15 @@ func (s *Server) acceptLoop() {
continue
}
if !s.tryAcquireConnSlot() {
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
_ = conn.Close()
continue
}
enableTCPKeepAlive(conn, s.log)
s.trackConn(conn)
go func(c net.Conn) {
defer s.releaseConnSlot()
defer s.untrackConn(c)
s.handleConnection(c)
}(conn)
@@ -565,16 +674,17 @@ func (s *Server) handleConnection(conn net.Conn) {
if !s.verifyAgentToken(conn, connLog) {
return
}
header, err := readConnectionHeader(conn)
header, err := s.readConnectionHeader(conn)
if err != nil {
connLog.Warnf("read connection header: %v", err)
conn.Close()
return
}
connLog, jwtUserID, ok := s.authorizeJWT(conn, header, connLog)
connLog, sessionUserID, ok := s.authorizeSession(conn, header, connLog)
if !ok {
return
}
s.registerConnAuth(conn, header)
capturer, injector, sessionCleanup, ok := s.acquireSessionResources(conn, header, &connLog)
if !ok {
@@ -586,7 +696,7 @@ func (s *Server) handleConnection(conn net.Conn) {
RemoteAddress: conn.RemoteAddr().String(),
Mode: modeString(header.mode),
Username: header.username,
JWTUsername: jwtUserID,
UserID: sessionUserID,
}, conn)
defer s.removeSession(sessionID)
@@ -596,13 +706,20 @@ func (s *Server) handleConnection(conn net.Conn) {
return
}
w, h := capturer.Width(), capturer.Height()
if w <= 0 || h <= 0 || w > maxFramebufferDim || h > maxFramebufferDim {
rejectConnection(conn, codeMessage(RejectCodeCapturerError, fmt.Sprintf("framebuffer dimensions out of range: %dx%d", w, h)))
connLog.Warnf("rejecting session: framebuffer %dx%d outside [1, %d]", w, h, maxFramebufferDim)
return
}
conn = newMetricsConn(conn, s.sessionRecorder)
sess := &session{
conn: conn,
capturer: capturer,
injector: injector,
serverW: capturer.Width(),
serverH: capturer.Height(),
serverW: w,
serverH: h,
log: connLog,
}
sess.serve()
@@ -615,25 +732,6 @@ func codeMessage(code, msg string) string {
return code + ": " + msg
}
// jwtErrorCode maps a JWT auth error to a stable reject code.
func jwtErrorCode(err error) string {
if err == nil {
return RejectCodeJWTInvalid
}
if errors.Is(err, nbjwt.ErrTokenExpired) {
return RejectCodeJWTExpired
}
msg := err.Error()
switch {
case strings.Contains(msg, "JWT required but not provided"):
return RejectCodeJWTMissing
case strings.Contains(msg, "authorize") || strings.Contains(msg, "not authorized"):
return RejectCodeAuthForbidden
default:
return RejectCodeJWTInvalid
}
}
// rejectConnection sends a minimal RFB handshake with a security failure
// reason, so VNC clients display the error message instead of a generic
// "unexpected disconnect."
@@ -658,105 +756,57 @@ func rejectConnection(conn net.Conn, reason string) {
_, _ = conn.Write(buf)
}
const defaultJWTMaxTokenAge = 10 * 60 // 10 minutes
// authenticateJWT validates the JWT from the connection header and checks
// authorization. For attach mode, just checks membership in the authorized
// user list. For session mode, additionally validates the OS user mapping.
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
if header.jwt == "" {
return "", fmt.Errorf("JWT required but not provided")
// authenticateSession resolves the Noise-verified client static public
// key to a hashed user identity via the authorizer, and checks OS-user
// mapping for session mode. Returns the hashed user identity on success.
func (s *Server) authenticateSession(header *connectionHeader) (string, error) {
if !header.identityVerified {
return "", fmt.Errorf("identity proof missing")
}
if len(header.clientStatic) != 32 {
return "", fmt.Errorf("client static key missing")
}
s.mu.Lock()
if err := s.ensureJWTValidator(); err != nil {
s.mu.Unlock()
return "", fmt.Errorf("initialize JWT validator: %w", err)
}
validator := s.jwtValidator
extractor := s.jwtExtractor
s.mu.Unlock()
token, err := validator.ValidateAndParse(context.Background(), header.jwt)
userIDHash, err := s.authorizer.LookupSessionKey(header.clientStatic)
if err != nil {
return "", fmt.Errorf("validate JWT: %w", err)
return "", fmt.Errorf("lookup session pubkey: %w", err)
}
if err := s.checkTokenAge(token); err != nil {
return "", err
osUser := "*"
if header.mode == ModeSession {
osUser = header.username
}
userAuth, err := extractor.ToUserAuth(token)
if err != nil {
return "", fmt.Errorf("extract user from JWT: %w", err)
if _, err := s.authorizer.AuthorizeOSUserBySessionKey(userIDHash, osUser); err != nil {
return "", fmt.Errorf("authorize OS user %q: %w", osUser, err)
}
if userAuth.UserId == "" {
return "", fmt.Errorf("JWT has no user ID")
}
switch header.mode {
case ModeSession:
// Session mode: check user + OS username mapping.
if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
}
default:
// Attach mode: just check user is in the authorized list (wildcard OS user).
if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
return "", fmt.Errorf("user not authorized for VNC: %w", err)
}
}
return userAuth.UserId, nil
return userIDHash.String(), nil
}
// ensureJWTValidator lazily initializes the JWT validator. Must be called with mu held.
func (s *Server) ensureJWTValidator() error {
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
if s.jwtConfig == nil {
return fmt.Errorf("no JWT config")
}
var vncIdentityMagic = []byte("NBV3")
// Enable IdP key refresh so JWKS rotations don't latch the validator
// off until daemon restart.
s.jwtValidator = nbjwt.NewValidator(
s.jwtConfig.Issuer,
s.jwtConfig.Audiences,
s.jwtConfig.KeysLocation,
true,
)
// Noise_IK_25519_ChaChaPoly_SHA256 message sizes (with empty payloads).
// msg1 = e(32) + s_AEAD(32+16) + payload_AEAD(0+16) = 96 bytes
// msg2 = e(32) + payload_AEAD(0+16) = 48 bytes
const (
noiseInitiatorMsgLen = 96
noiseResponderMsgLen = 48
)
var opts []nbjwt.ClaimsExtractorOption
if len(s.jwtConfig.Audiences) > 0 {
opts = append(opts, nbjwt.WithAudience(s.jwtConfig.Audiences[0]))
}
if claim := s.authorizer.GetUserIDClaim(); claim != "" {
opts = append(opts, nbjwt.WithUserIDClaim(claim))
}
s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
// vncNoiseSuite pins the cipher suite for the VNC handshake. Changing
// it requires bumping vncIdentityMagic so old clients fail closed.
var vncNoiseSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
return nil
}
func (s *Server) checkTokenAge(token *gojwt.Token) error {
maxAge := defaultJWTMaxTokenAge
if s.jwtConfig != nil && s.jwtConfig.MaxTokenAge > 0 {
maxAge = int(s.jwtConfig.MaxTokenAge)
}
return nbjwt.CheckTokenAge(token, time.Duration(maxAge)*time.Second)
}
// readConnectionHeader reads the NetBird VNC session header from the connection.
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
// readConnectionHeader reads the NetBird VNC session header. Format:
//
// [jwt_len: 2 bytes BE] [jwt: N bytes]
// [mode: 1] [username_len: 2 BE] [username: N]
// [opt magic "NBV3": 4] [noise_msg1: 96]
// (server writes [noise_msg2: 48] here when the magic is present)
// [session_id: 4 BE] [width: 2 BE] [height: 2 BE]
//
// Uses a short timeout: our WASM proxy sends the header immediately after
// connecting. Standard VNC clients don't send anything first (server speaks
// first in RFB), so they time out and get the default attach mode.
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
// Standard VNC clients don't speak first, so they time out on the first
// read and fall through to attach mode (which auth still rejects when
// no Noise handshake completed).
func (s *Server) readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
@@ -764,11 +814,9 @@ func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
var hdr [3]byte
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
// Timeout or error: assume no header, use attach mode.
return &connectionHeader{mode: ModeAttach}, nil
}
// Restore a longer deadline for reading variable-length fields.
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
@@ -788,48 +836,93 @@ func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
username = string(buf)
}
// Read JWT token length and data.
var jwtLenBuf [2]byte
var jwtToken string
if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
if jwtLen >= 8192 {
return nil, fmt.Errorf("jwt too long: %d (max 8191)", jwtLen)
}
if jwtLen > 0 {
buf := make([]byte, jwtLen)
if _, err := io.ReadFull(conn, buf); err != nil {
return nil, fmt.Errorf("read JWT: %w", err)
}
jwtToken = string(buf)
}
br := bufio.NewReader(conn)
clientStatic, identityVerified, err := s.maybeRunNoiseHandshake(conn, br)
if err != nil {
return nil, err
}
// Read optional Windows session ID (4 bytes BE). Missing = 0 (console/auto).
var sessionID uint32
var sidBuf [4]byte
if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
if _, err := io.ReadFull(br, sidBuf[:]); err == nil {
sessionID = binary.BigEndian.Uint32(sidBuf[:])
}
// Read optional requested viewport size (2x uint16 BE). Missing = 0 (default).
var width, height uint16
var geomBuf [4]byte
if _, err := io.ReadFull(conn, geomBuf[:]); err == nil {
if _, err := io.ReadFull(br, geomBuf[:]); err == nil {
width = binary.BigEndian.Uint16(geomBuf[0:2])
height = binary.BigEndian.Uint16(geomBuf[2:4])
}
return &connectionHeader{
mode: mode,
username: username,
jwt: jwtToken,
sessionID: sessionID,
width: width,
height: height,
mode: mode,
username: username,
clientStatic: clientStatic,
sessionID: sessionID,
width: width,
height: height,
identityVerified: identityVerified,
}, nil
}
// maybeRunNoiseHandshake performs the responder side of a Noise_IK
// handshake when the client sends the v3 magic. Returns the client static
// public key learned from the handshake. Any handshake failure is fatal
// (fail closed).
func (s *Server) maybeRunNoiseHandshake(conn net.Conn, br *bufio.Reader) ([]byte, bool, error) {
peek, err := br.Peek(len(vncIdentityMagic))
if err != nil || !bytes.Equal(peek, vncIdentityMagic) {
return nil, false, nil
}
if _, err := br.Discard(len(vncIdentityMagic)); err != nil {
return nil, false, fmt.Errorf("discard identity magic: %w", err)
}
msg1 := make([]byte, noiseInitiatorMsgLen)
if _, err := io.ReadFull(br, msg1); err != nil {
return nil, false, fmt.Errorf("read noise msg1: %w", err)
}
// Agents on loopback authenticate via the agent token, not this
// handshake. Consume the replayed bytes and skip the response.
if s.disableAuth {
return nil, true, nil
}
if len(s.identityKey) != 32 || len(s.identityPublic) != 32 {
return nil, false, errors.New("identity key not configured")
}
state, err := noise.NewHandshakeState(noise.Config{
CipherSuite: vncNoiseSuite,
Pattern: noise.HandshakeIK,
Initiator: false,
StaticKeypair: noise.DHKey{Private: s.identityKey, Public: s.identityPublic},
})
if err != nil {
return nil, false, fmt.Errorf("noise responder init: %w", err)
}
if _, _, _, err := state.ReadMessage(nil, msg1); err != nil {
return nil, false, fmt.Errorf("noise read msg1: %w", err)
}
msg2, _, _, err := state.WriteMessage(nil, nil)
if err != nil {
return nil, false, fmt.Errorf("noise write msg2: %w", err)
}
if len(msg2) != noiseResponderMsgLen {
return nil, false, fmt.Errorf("noise responder produced %d bytes, expected %d", len(msg2), noiseResponderMsgLen)
}
if _, err := conn.Write(msg2); err != nil {
return nil, false, fmt.Errorf("write noise msg2: %w", err)
}
clientStatic := state.PeerStatic()
if len(clientStatic) != 32 {
return nil, false, errors.New("noise peer static missing")
}
return clientStatic, true, nil
}
// verifyAgentToken validates the agent token prefix when configured. Returns
// false when the token is invalid or unreadable; the connection is closed.
func (s *Server) verifyAgentToken(conn net.Conn, connLog *log.Entry) bool {
@@ -865,25 +958,34 @@ func (s *Server) verifyAgentToken(conn net.Conn, connLog *log.Entry) bool {
return true
}
// authorizeJWT performs JWT validation when auth is enabled. Returns the
// enriched log entry, jwt user ID (empty when auth disabled), and ok=false
// if the connection was rejected.
func (s *Server) authorizeJWT(conn net.Conn, header *connectionHeader, connLog *log.Entry) (*log.Entry, string, bool) {
// authorizeSession runs the Noise_IK handshake when auth is enabled.
// Returns the enriched log entry, user identity hash (empty when auth
// disabled), and ok=false if the connection was rejected.
func (s *Server) authorizeSession(conn net.Conn, header *connectionHeader, connLog *log.Entry) (*log.Entry, string, bool) {
if s.disableAuth {
return connLog, "", true
}
if s.jwtConfig == nil {
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
connLog.Warn("auth rejected: no identity provider configured")
return connLog, "", false
}
jwtUserID, err := s.authenticateJWT(header)
userID, err := s.authenticateSession(header)
if err != nil {
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
connLog.Warnf("auth rejected: %v", err)
return connLog, "", false
}
return connLog.WithField("jwt_user", jwtUserID), jwtUserID, true
return connLog.WithFields(log.Fields{
"session_user": userID,
"session_key": sessionKeyFingerprint(header.clientStatic),
}), userID, true
}
// sessionKeyFingerprint returns a short hex fingerprint of a client
// static key for log correlation. Distinct VNC sessions of the same
// user end up with distinct fingerprints because each session mints a
// fresh keypair, so this lets an operator tell parallel sessions apart.
func sessionKeyFingerprint(clientStatic []byte) string {
if len(clientStatic) < 4 {
return ""
}
return hex.EncodeToString(clientStatic[:4])
}
// acquireSessionResources returns the capturer/injector to use for this

View File

@@ -47,10 +47,16 @@ func (s *Server) serviceAcceptLoop() {
continue
}
if !s.tryAcquireConnSlot() {
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
_ = conn.Close()
continue
}
enableTCPKeepAlive(conn, s.log)
conn = newMetricsConn(conn, s.sessionRecorder)
s.trackConn(conn)
go func(c net.Conn) {
defer s.releaseConnSlot()
defer s.untrackConn(c)
s.handleServiceConnectionDarwin(c, mgr)
}(conn)
@@ -69,7 +75,7 @@ func (s *Server) handleServiceConnectionDarwin(conn net.Conn, mgr *darwinAgentMa
tee := io.TeeReader(conn, &headerBuf)
teeConn := &darwinPrefixConn{Reader: tee, Conn: conn}
header, err := readConnectionHeader(teeConn)
header, err := s.readConnectionHeader(teeConn)
if err != nil {
connLog.Debugf("read connection header: %v", err)
conn.Close()
@@ -77,17 +83,13 @@ func (s *Server) handleServiceConnectionDarwin(conn net.Conn, mgr *darwinAgentMa
}
if !s.disableAuth {
if s.jwtConfig == nil {
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
connLog.Warn("auth rejected: no identity provider configured")
return
}
if _, err := s.authenticateJWT(header); err != nil {
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
if _, err := s.authenticateSession(header); err != nil {
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
connLog.Warnf("auth rejected: %v", err)
return
}
}
s.registerConnAuth(conn, header)
token, err := mgr.ensure(s.ctx)
if err != nil {

View File

@@ -9,7 +9,6 @@ import (
"io"
"net"
"net/netip"
"strings"
"testing"
"time"
@@ -26,14 +25,11 @@ func (t *testCapturer) Capture() (*image.RGBA, error) {
return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil
}
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
func startTestServer(t *testing.T, disableAuth bool) (net.Addr, *Server) {
t.Helper()
srv := New(&testCapturer{}, &StubInputInjector{})
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
srv.SetDisableAuth(disableAuth)
if jwtConfig != nil {
srv.SetJWTConfig(jwtConfig)
}
addr := netip.MustParseAddrPort("127.0.0.1:0")
network := netip.MustParsePrefix("127.0.0.0/8")
@@ -45,30 +41,28 @@ func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.
return srv.listener.Addr(), srv
}
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
addr, _ := startTestServer(t, false, nil)
func TestAuthEnabled_NoSessionAuth_RejectsConnection(t *testing.T) {
addr, _ := startTestServer(t, false)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header: attach mode, no username, no JWT.
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
// Header with no Noise handshake. Auth-required servers must reject
// because no client static was authenticated.
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
header[0] = ModeAttach
_, err = conn.Write(header)
require.NoError(t, err)
// Server should send RFB version then security failure.
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
assert.Equal(t, "RFB 003.008\n", string(version[:]))
// Write client version to proceed through handshake.
_, err = conn.Write(version[:])
require.NoError(t, err)
// Read security types: 0 means failure, followed by reason.
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
@@ -81,18 +75,17 @@ func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
_, err = io.ReadFull(conn, reason)
require.NoError(t, err)
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
assert.Contains(t, string(reason), "identity proof missing", "rejection reason should mention missing identity proof")
}
func TestAuthDisabled_AllowsConnection(t *testing.T) {
addr, _ := startTestServer(t, true, nil)
addr, _ := startTestServer(t, true)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header: attach mode, no username, no JWT.
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
header[0] = ModeAttach
_, err = conn.Write(header)
require.NoError(t, err)
@@ -114,70 +107,12 @@ func TestAuthDisabled_AllowsConnection(t *testing.T) {
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
}
// TestAuthEnabled_InvalidJWT_RejectedBeforeRFB confirms the VNC server itself
// (not just the JWT library) wires authentication into handleConnection. A
// well-formed JWT-shaped token must hit the server's validation path and be
// rejected with an AUTH_JWT_* reason, never reaching the RFB handshake.
func TestAuthEnabled_InvalidJWT_RejectedBeforeRFB(t *testing.T) {
addr, _ := startTestServer(t, false, &JWTConfig{
Issuer: "https://example.invalid",
KeysLocation: "https://example.invalid/.well-known/jwks.json",
Audiences: []string{"test"},
})
// Three-segment "JWT" with bogus base64. The server's authenticateJWT path
// must catch this regardless of the IdP being unreachable.
bogusJWT := "abc.def.ghi"
header := make([]byte, 3+2+len(bogusJWT)+4+4)
header[0] = ModeAttach
binary.BigEndian.PutUint16(header[1:3], 0) // username len
binary.BigEndian.PutUint16(header[3:5], uint16(len(bogusJWT)))
copy(header[5:5+len(bogusJWT)], bogusJWT)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
_, err = conn.Write(header)
require.NoError(t, err)
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
_, err = conn.Write(version[:])
require.NoError(t, err)
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
require.Equal(t, byte(0), numTypes[0], "must fail security negotiation")
var reasonLen [4]byte
_, err = io.ReadFull(conn, reasonLen[:])
require.NoError(t, err)
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
_, err = io.ReadFull(conn, reason)
require.NoError(t, err)
// The reason must carry one of the server's AUTH_JWT_* codes, proving
// the rejection came from authenticateJWT in handleConnection.
r := string(reason)
hasJWTReject := false
for _, code := range []string{RejectCodeJWTInvalid, RejectCodeJWTExpired, RejectCodeAuthForbidden} {
if strings.Contains(r, code) {
hasJWTReject = true
break
}
}
assert.True(t, hasJWTReject, "reason %q must include an AUTH_JWT_* code", r)
}
// TestAuth_NoUnauthBytesPastHeader proves the server does not send any RFB
// content to a connection that fails source validation. Specifically, the
// server must close immediately and the client must see EOF before any RFB
// version greeting is written.
func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
srv := New(&testCapturer{}, &StubInputInjector{})
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
srv.SetDisableAuth(true)
addr := netip.MustParseAddrPort("127.0.0.1:0")
// Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so
@@ -198,37 +133,6 @@ func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
require.Error(t, err, "non-overlay client must see EOF, not an RFB greeting")
}
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
addr, _ := startTestServer(t, false, &JWTConfig{
Issuer: "https://example.com",
KeysLocation: "https://example.com/.well-known/jwks.json",
Audiences: []string{"test"},
})
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header with empty JWT.
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
header[0] = ModeAttach
_, err = conn.Write(header)
require.NoError(t, err)
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
_, err = conn.Write(version[:])
require.NoError(t, err)
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
}
func TestIsAllowedSource(t *testing.T) {
tests := []struct {
name string
@@ -289,7 +193,7 @@ func TestIsAllowedSource(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
srv := New(&testCapturer{}, &StubInputInjector{})
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
srv.localAddr = tc.localAddr
srv.network = tc.network
assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote))
@@ -298,7 +202,7 @@ func TestIsAllowedSource(t *testing.T) {
}
func TestStart_InvalidNetworkRejected(t *testing.T) {
srv := New(&testCapturer{}, &StubInputInjector{})
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
addr := netip.MustParseAddrPort("127.0.0.1:0")
err := srv.Start(t.Context(), addr, netip.Prefix{})
require.Error(t, err, "Start must refuse an invalid overlay prefix")
@@ -306,7 +210,7 @@ func TestStart_InvalidNetworkRejected(t *testing.T) {
}
func TestAgentToken_MismatchClosesConnection(t *testing.T) {
srv := New(&testCapturer{}, &StubInputInjector{})
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
srv.SetDisableAuth(true)
srv.SetAgentToken("deadbeefcafebabe")
@@ -334,7 +238,7 @@ func TestAgentToken_MismatchClosesConnection(t *testing.T) {
}
func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
srv := New(&testCapturer{}, &StubInputInjector{})
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
srv.SetDisableAuth(true)
const tokenHex = "deadbeefcafebabe"
srv.SetAgentToken(tokenHex)
@@ -356,7 +260,7 @@ func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
require.NoError(t, err)
// Send session header so handleConnection can proceed past readConnectionHeader.
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
header := make([]byte, 11) // ModeAttach + usernameLen=0 + sessionID=0 + width=0 + height=0
header[0] = ModeAttach
_, err = conn.Write(header)
require.NoError(t, err)
@@ -371,7 +275,7 @@ func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
// Default platformSessionManager() on non-Linux returns nil, so ModeSession
// must be rejected with the UNSUPPORTED reason rather than crashing.
srv := New(&testCapturer{}, &StubInputInjector{})
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
srv.SetDisableAuth(true)
addr := netip.MustParseAddrPort("127.0.0.1:0")
@@ -387,7 +291,7 @@ func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
defer conn.Close()
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
// ModeSession with no username/JWT, so we exit on the vmgr==nil branch
// ModeSession with no username, so we exit on the vmgr==nil branch
// before username validation runs.
header := []byte{ModeSession, 0, 0, 0, 0}
_, err = conn.Write(header)

View File

@@ -233,8 +233,9 @@ func (s *Server) platformInit() {
startSASListener(s.ctx)
}
// serviceAcceptLoop runs in Session 0. It validates source IP and
// authenticates via JWT before proxying connections to the user-session agent.
// serviceAcceptLoop runs in Session 0. It validates the source IP and
// hands accepted connections to handleServiceConnection, which runs the
// Noise_IK handshake before proxying to the user-session agent.
func (s *Server) serviceAcceptLoop() {
sm := newSessionManager(agentPort)
@@ -255,18 +256,25 @@ func (s *Server) serviceAcceptLoop() {
continue
}
if !s.tryAcquireConnSlot() {
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
_ = conn.Close()
continue
}
enableTCPKeepAlive(conn, s.log)
conn = newMetricsConn(conn, s.sessionRecorder)
s.trackConn(conn)
go func(c net.Conn) {
defer s.releaseConnSlot()
defer s.untrackConn(c)
s.handleServiceConnection(c, sm)
}(conn)
}
}
// handleServiceConnection validates the source IP and JWT, then proxies
// the connection (with header bytes replayed) to the agent.
// handleServiceConnection runs the connection-header handshake (including
// Noise_IK), then proxies the connection (with header bytes replayed) to
// the agent listening on loopback.
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
@@ -279,7 +287,7 @@ func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
tee := io.TeeReader(conn, &headerBuf)
teeConn := &prefixConn{Reader: tee, Conn: conn}
header, err := readConnectionHeader(teeConn)
header, err := s.readConnectionHeader(teeConn)
if err != nil {
connLog.Debugf("read connection header: %v", err)
conn.Close()
@@ -287,17 +295,13 @@ func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
}
if !s.disableAuth {
if s.jwtConfig == nil {
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
connLog.Warn("auth rejected: no identity provider configured")
return
}
if _, err := s.authenticateJWT(header); err != nil {
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
if _, err := s.authenticateSession(header); err != nil {
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
connLog.Warnf("auth rejected: %v", err)
return
}
}
s.registerConnAuth(conn, header)
// Replay buffered header bytes + remaining stream to the agent.
replayConn := &prefixConn{

View File

@@ -222,9 +222,9 @@ func (s *session) handshake() error {
}
// sendSecurityTypes advertises only secNone. Authentication and access
// control happen in the NetBird connection header (JWT, mode, username)
// that precedes the RFB handshake, not via the protocol-level password
// scheme.
// control happen in the NetBird connection header (Noise_IK handshake,
// mode, username) that precedes the RFB handshake; the protocol-level
// password scheme is not supported.
func (s *session) sendSecurityTypes() error {
_, err := s.conn.Write([]byte{1, secNone})
return err

View File

@@ -225,6 +225,10 @@ func (s *session) handleResize() error {
if w <= 0 || h <= 0 {
return nil
}
if w > maxFramebufferDim || h > maxFramebufferDim {
s.log.Warnf("ignoring resize: %dx%d exceeds cap %d", w, h, maxFramebufferDim)
return nil
}
if w == s.serverW && h == s.serverH {
return nil
}

View File

@@ -4,6 +4,7 @@ package main
import (
"context"
"encoding/base64"
"fmt"
"net"
"strconv"
@@ -39,6 +40,7 @@ const (
func main() {
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
js.Global().Set("netbirdGenerateVNCSessionKey", createGenerateVNCSessionKeyMethod())
select {}
}
@@ -388,13 +390,31 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
})
}
// createGenerateVNCSessionKeyMethod returns a JS func that mints a fresh
// X25519 keypair, stashes the private half inside wasm under a random
// session id, and returns { publicKey, sessionId } to JS. The private
// key never leaves the wasm heap.
func createGenerateVNCSessionKeyMethod() js.Func {
return js.FuncOf(func(_ js.Value, _ []js.Value) any {
id, pub, err := vnc.NewSessionKey()
if err != nil {
return js.ValueOf(err.Error())
}
out := js.Global().Get("Object").New()
out.Set("sessionId", id)
out.Set("publicKey", base64.StdEncoding.EncodeToString(pub))
return out
})
}
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
// JS signature: createVNCProxy(hostname, port, mode?, username?, jwt?, sessionID?, width?, height?)
// mode: "attach" (default) or "session"
// username: required when mode is "session"
// jwt: authentication token (from OIDC session)
// sessionID: Windows session ID (0 = console/auto)
// width/height: requested viewport size for session mode (0 = server default)
// JS signature: createVNCProxy(hostname, port, mode?, username?, keySessionID?, sessionID?, width?, height?, peerPublicKey?)
// mode: "attach" (default) or "session"
// username: required when mode is "session"
// keySessionID: handle for the wasm-resident session keypair minted by netbirdGenerateVNCSessionKey
// sessionID: Windows session ID (0 = console/auto)
// width/height: requested viewport size for session mode (0 = server default)
// peerPublicKey: base64 X25519 static pubkey of the destination peer (required for auth)
func createVNCProxyMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
params, err := parseVNCProxyArgs(args)
@@ -408,14 +428,15 @@ func createVNCProxyMethod(client *netbird.Client) js.Func {
}
proxy := vnc.NewVNCProxy(client)
return proxy.CreateProxy(vnc.ProxyRequest{
Hostname: params.hostname,
Port: params.port,
Mode: params.mode,
Username: params.username,
JWT: params.jwt,
SessionID: params.sessionID,
Width: params.width,
Height: params.height,
Hostname: params.hostname,
Port: params.port,
Mode: params.mode,
Username: params.username,
SessionID: params.sessionID,
Width: params.width,
Height: params.height,
PeerPublicKey: params.peerPublicKey,
KeySessionID: params.keySessionID,
})
})
}
@@ -425,11 +446,12 @@ type vncProxyParams struct {
port string
mode string
username string
jwt string
keySessionID string
sessionID uint32
width uint16
height uint16
rejectViaPromise bool // true when the JS caller expects a rejected Promise instead of a plain string return
peerPublicKey string
rejectViaPromise bool
}
// parseVNCProxyArgs validates JS args for createVNCProxyMethod and returns
@@ -480,7 +502,7 @@ func parseVNCProxyOptionalStrings(args []js.Value, p *vncProxyParams) error {
p.username = args[3].String()
}
if len(args) > 4 && args[4].Type() == js.TypeString {
p.jwt = args[4].String()
p.keySessionID = args[4].String()
}
return nil
}
@@ -512,6 +534,9 @@ func parseVNCProxyOptionalNumbers(args []js.Value, p *vncProxyParams) error {
}
p.height = uint16(v)
}
if len(args) > 8 && args[8].Type() == js.TypeString {
p.peerPublicKey = args[8].String()
}
return nil
}

View File

@@ -4,6 +4,8 @@ package vnc
import (
"context"
crand "crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -13,9 +15,65 @@ import (
"syscall/js"
"time"
"github.com/flynn/noise"
log "github.com/sirupsen/logrus"
)
var cryptoRandRead = crand.Read
// vncIdentityMagic mirrors the server side in client/vnc/server/server.go.
var vncIdentityMagic = []byte("NBV3")
// Noise_IK_25519_ChaChaPoly_SHA256 message sizes (with empty payloads).
const (
noiseInitiatorMsgLen = 96
noiseResponderMsgLen = 48
)
var vncNoiseSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
// sessionKeyStore retains per-session X25519 keypairs so the JS layer
// only sees an opaque session id + the public key; the private key never
// leaves wasm.
var sessionKeyStore = struct {
mu sync.Mutex
keys map[string]noise.DHKey
}{keys: map[string]noise.DHKey{}}
// NewSessionKey mints an X25519 keypair, stores the private half under a
// fresh random session id, and returns (id, pubkey).
func NewSessionKey() (string, []byte, error) {
kp, err := noise.DH25519.GenerateKeypair(nil)
if err != nil {
return "", nil, fmt.Errorf("generate keypair: %w", err)
}
idBytes := make([]byte, 16)
if _, err := cryptoRandRead(idBytes); err != nil {
return "", nil, fmt.Errorf("session id randomness: %w", err)
}
id := base64.RawURLEncoding.EncodeToString(idBytes)
sessionKeyStore.mu.Lock()
sessionKeyStore.keys[id] = kp
sessionKeyStore.mu.Unlock()
return id, kp.Public, nil
}
// lookupSessionKey returns the keypair for id, or false if unknown.
func lookupSessionKey(id string) (noise.DHKey, bool) {
sessionKeyStore.mu.Lock()
defer sessionKeyStore.mu.Unlock()
kp, ok := sessionKeyStore.keys[id]
return kp, ok
}
// dropSessionKey removes the keypair for id. Called after the VNC
// connection closes (or after a connect attempt fails terminally).
func dropSessionKey(id string) {
sessionKeyStore.mu.Lock()
delete(sessionKeyStore.keys, id)
sessionKeyStore.mu.Unlock()
}
const (
vncProxyHost = "vnc.proxy.local"
vncProxyScheme = "ws"
@@ -37,10 +95,12 @@ const (
// VNCProxy bridges WebSocket connections from noVNC in the browser
// to TCP VNC server connections through the NetBird tunnel.
type vncNBClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
type VNCProxy struct {
nbClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
nbClient vncNBClient
activeConnections map[string]*vncConnection
destinations map[string]vncDestination
// pendingHandlers holds the js.Func for handleVNCWebSocket_<id> between
@@ -52,13 +112,15 @@ type VNCProxy struct {
}
type vncDestination struct {
address string
mode byte
username string
jwt string
sessionID uint32 // Windows session ID (0 = auto/console)
width uint16 // Requested viewport width for session mode (0 = default)
height uint16 // Requested viewport height for session mode (0 = default)
address string
mode byte
username string
sessionPriv []byte
sessionPub []byte
sessionID uint32
width uint16
height uint16
peerPubKey []byte
}
type vncConnection struct {
@@ -78,9 +140,7 @@ type vncConnection struct {
}
// NewVNCProxy creates a new VNC proxy.
func NewVNCProxy(client interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}) *VNCProxy {
func NewVNCProxy(client vncNBClient) *VNCProxy {
return &VNCProxy{
nbClient: client,
activeConnections: make(map[string]*vncConnection),
@@ -94,10 +154,16 @@ type ProxyRequest struct {
Port string
Mode string
Username string
JWT string
SessionID uint32
Width uint16
Height uint16
// PeerPublicKey is the destination peer's base64 X25519 public key,
// used as the responder static in the Noise_IK handshake.
PeerPublicKey string
// KeySessionID is the handle returned by generateVNCSessionKey. The
// matching private key is looked up inside wasm and never crosses
// the JS boundary.
KeySessionID string
}
// CreateProxy creates a new proxy endpoint for the given VNC destination.
@@ -106,7 +172,7 @@ type ProxyRequest struct {
// virtual display geometry for session mode; 0 means use the server default.
// Returns a JS Promise that resolves to the WebSocket proxy URL.
func (p *VNCProxy) CreateProxy(req ProxyRequest) js.Value {
hostname, port, mode, username, jwt := req.Hostname, req.Port, req.Mode, req.Username, req.JWT
hostname, port, mode, username := req.Hostname, req.Port, req.Mode, req.Username
sessionID, width, height := req.SessionID, req.Width, req.Height
address := net.JoinHostPort(hostname, port)
@@ -119,14 +185,51 @@ func (p *VNCProxy) CreateProxy(req ProxyRequest) js.Value {
address: address,
mode: m,
username: username,
jwt: jwt,
sessionID: sessionID,
width: width,
height: height,
}
if req.KeySessionID != "" {
kp, ok := lookupSessionKey(req.KeySessionID)
if !ok {
return rejectedPromise("unknown VNC session id")
}
// A session handle is single-use; drop it before the destination
// holds the private bytes so a leaked handle can't be replayed.
dropSessionKey(req.KeySessionID)
dest.sessionPriv = kp.Private
dest.sessionPub = kp.Public
pub, err := decodePeerPubKey(req.PeerPublicKey)
if err != nil {
return rejectedPromise(fmt.Sprintf("invalid peer public key: %v", err))
}
dest.peerPubKey = pub
}
return p.newProxyPromise(address, mode, username, dest)
}
// decodePeerPubKey parses a base64-encoded 32-byte X25519 public key.
func decodePeerPubKey(b64 string) ([]byte, error) {
if b64 == "" {
return nil, errors.New("peer public key missing")
}
raw, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
if len(raw) != 32 {
return nil, fmt.Errorf("expected 32 bytes, got %d", len(raw))
}
return raw, nil
}
// rejectedPromise returns a resolved Promise carrying msg as an error
// string, mirroring how CreateProxy reports earlier validation failures.
func rejectedPromise(msg string) js.Value {
promise := js.Global().Get("Promise")
return promise.Call("resolve", js.ValueOf(msg))
}
// newProxyPromise wraps the JS Promise creation + executor lifecycle so
// CreateProxy stays a thin parameter-bundling entrypoint.
func (p *VNCProxy) newProxyPromise(address, mode, username string, dest vncDestination) js.Value {
@@ -288,46 +391,95 @@ func (p *VNCProxy) connectToVNC(conn *vncConnection) {
p.cleanupConnection(conn)
}
// sendSessionHeader writes mode, username, JWT, Windows session ID, and the
// requested viewport size to the VNC server.
// Format: [mode:1] [username_len:2] [username:N] [jwt_len:2] [jwt:N]
//
// [session_id:4] [width:2] [height:2]
// sendSessionHeader writes the NetBird VNC connection header: mode +
// username prefix, an optional Noise_IK handshake that authenticates the
// client and the server, then the trailing sessionID / width / height
// fields the daemon needs once auth is settled.
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
usernameBytes := []byte(dest.username)
jwtBytes := []byte(dest.jwt)
if len(usernameBytes) > 0xFFFF {
return fmt.Errorf("username too long: %d bytes (max %d)", len(usernameBytes), 0xFFFF)
}
if len(jwtBytes) > 0xFFFF {
return fmt.Errorf("jwt too long: %d bytes (max %d)", len(jwtBytes), 0xFFFF)
prefix := make([]byte, 3+len(usernameBytes))
prefix[0] = dest.mode
prefix[1] = byte(len(usernameBytes) >> 8)
prefix[2] = byte(len(usernameBytes))
copy(prefix[3:], usernameBytes)
if err := writeAll(conn, prefix); err != nil {
return fmt.Errorf("write header prefix: %w", err)
}
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes)+4+4)
hdr[0] = dest.mode
hdr[1] = byte(len(usernameBytes) >> 8)
hdr[2] = byte(len(usernameBytes))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
hdr[off] = byte(len(jwtBytes) >> 8)
hdr[off+1] = byte(len(jwtBytes))
off += 2
copy(hdr[off:], jwtBytes)
off += len(jwtBytes)
hdr[off] = byte(dest.sessionID >> 24)
hdr[off+1] = byte(dest.sessionID >> 16)
hdr[off+2] = byte(dest.sessionID >> 8)
hdr[off+3] = byte(dest.sessionID)
off += 4
hdr[off] = byte(dest.width >> 8)
hdr[off+1] = byte(dest.width)
hdr[off+2] = byte(dest.height >> 8)
hdr[off+3] = byte(dest.height)
for off := 0; off < len(hdr); {
n, err := conn.Write(hdr[off:])
if dest.sessionPriv == nil {
return p.writeHeaderTail(conn, dest)
}
if err := p.runNoiseHandshake(conn, dest); err != nil {
return fmt.Errorf("noise handshake: %w", err)
}
return p.writeHeaderTail(conn, dest)
}
// writeHeaderTail writes the post-auth trailing fields (sessionID,
// width, height) the daemon reads regardless of whether the Noise
// handshake was performed.
func (p *VNCProxy) writeHeaderTail(conn net.Conn, dest vncDestination) error {
tail := make([]byte, 4+4)
tail[0] = byte(dest.sessionID >> 24)
tail[1] = byte(dest.sessionID >> 16)
tail[2] = byte(dest.sessionID >> 8)
tail[3] = byte(dest.sessionID)
tail[4] = byte(dest.width >> 8)
tail[5] = byte(dest.width)
tail[6] = byte(dest.height >> 8)
tail[7] = byte(dest.height)
if err := writeAll(conn, tail); err != nil {
return fmt.Errorf("write header tail: %w", err)
}
return nil
}
// runNoiseHandshake performs the initiator side of a Noise_IK handshake
// against the destination daemon. The session keypair authenticates the
// client; the daemon's pre-known peer pubkey authenticates the server.
func (p *VNCProxy) runNoiseHandshake(conn net.Conn, dest vncDestination) error {
state, err := noise.NewHandshakeState(noise.Config{
CipherSuite: vncNoiseSuite,
Pattern: noise.HandshakeIK,
Initiator: true,
StaticKeypair: noise.DHKey{Private: dest.sessionPriv, Public: dest.sessionPub},
PeerStatic: dest.peerPubKey,
})
if err != nil {
return fmt.Errorf("noise initiator init: %w", err)
}
msg1, _, _, err := state.WriteMessage(nil, nil)
if err != nil {
return fmt.Errorf("noise write msg1: %w", err)
}
out := make([]byte, 0, len(vncIdentityMagic)+len(msg1))
out = append(out, vncIdentityMagic...)
out = append(out, msg1...)
if err := writeAll(conn, out); err != nil {
return fmt.Errorf("send noise msg1: %w", err)
}
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return fmt.Errorf("set noise deadline: %w", err)
}
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
msg2 := make([]byte, noiseResponderMsgLen)
if _, err := io.ReadFull(conn, msg2); err != nil {
return fmt.Errorf("read noise msg2: %w", err)
}
if _, _, _, err := state.ReadMessage(nil, msg2); err != nil {
return fmt.Errorf("noise read msg2: %w", err)
}
return nil
}
func writeAll(conn net.Conn, buf []byte) error {
for off := 0; off < len(buf); {
n, err := conn.Write(buf[off:])
if err != nil {
return fmt.Errorf("write session header: %w", err)
return err
}
off += n
}

1
go.mod
View File

@@ -185,6 +185,7 @@ require (
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/flynn/noise v1.1.0 // indirect
github.com/fredbi/uri v1.1.1 // indirect
github.com/fxamacker/cbor/v2 v2.9.1 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect

4
go.sum
View File

@@ -162,6 +162,8 @@ github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
@@ -412,6 +414,7 @@ github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoK
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
@@ -756,6 +759,7 @@ goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=

View File

@@ -2,6 +2,7 @@ package grpc
import (
"context"
"encoding/base64"
"fmt"
"net/netip"
"net/url"
@@ -184,12 +185,47 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
if networkMap.VNCAuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.VNCAuthorizedUsers)
response.NetworkMap.VncAuth = &proto.VNCAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
response.NetworkMap.VncAuth = &proto.VNCAuth{
AuthorizedUsers: hashedUsers,
MachineUsers: machineUsers,
SessionPubKeys: buildSessionPubKeysProto(ctx, networkMap.VNCSessionPubKeys),
}
}
return response
}
// buildSessionPubKeysProto decodes base64 X25519 session pubkeys and
// hashes the user IDs they belong to, emitting the proto entries the
// daemon's authorizer indexes by pubkey.
func buildSessionPubKeysProto(ctx context.Context, in []types.VNCSessionPubKey) []*proto.SessionPubKey {
if len(in) == 0 {
return nil
}
out := make([]*proto.SessionPubKey, 0, len(in))
for _, e := range in {
pub, err := base64.StdEncoding.DecodeString(e.PubKey)
if err != nil {
log.WithContext(ctx).Warnf("decode VNC session pubkey: %v", err)
continue
}
if len(pub) != 32 {
log.WithContext(ctx).Warnf("VNC session pubkey wrong length: %d", len(pub))
continue
}
hash, err := sshauth.HashUserID(e.UserID)
if err != nil {
log.WithContext(ctx).Warnf("hash VNC session user id: %v", err)
continue
}
out = append(out, &proto.SessionPubKey{
PubKey: pub,
UserIdHash: hash[:],
})
}
return out
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte

View File

@@ -517,6 +517,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
if protocol == types.PolicyRuleProtocolNetbirdSSH || protocol == types.PolicyRuleProtocolNetbirdVNC {
policy.Rules[0].AuthorizedUser = userAuth.UserId
}
if protocol == types.PolicyRuleProtocolNetbirdVNC && req.SessionPubKey != nil {
policy.Rules[0].SessionPubKey = *req.SessionPubKey
}
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
if err != nil {
@@ -526,9 +529,10 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
}
resp := &api.PeerTemporaryAccessResponse{
Id: peer.ID,
Name: peer.Name,
Rules: req.Rules,
Id: peer.ID,
Name: peer.Name,
Rules: req.Rules,
TargetPubKey: targetPeer.Key,
}
util.WriteJSONObject(r.Context(), w, resp)

View File

@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
})
t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -849,7 +849,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, map[string]map[string]struct{}, bool) {
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, map[string]map[string]struct{}, []VNCSessionPubKey, bool) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
ctxState := &peerConnResolveState{
authorizedUsers: make(map[string]map[string]struct{}),
@@ -869,7 +869,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
}
peers, fwRules := getAccumulatedResources()
return peers, fwRules, ctxState.authorizedUsers, ctxState.vncAuthorizedUsers, ctxState.sshEnabled
return peers, fwRules, ctxState.authorizedUsers, ctxState.vncAuthorizedUsers, ctxState.vncSessionPubKeys, ctxState.sshEnabled
}
func (a *Account) applyPolicyRule(

View File

@@ -49,6 +49,7 @@ type NetworkMap struct {
ForwardingRules []*ForwardingRule
AuthorizedUsers map[string]map[string]struct{}
VNCAuthorizedUsers map[string]map[string]struct{}
VNCSessionPubKeys []VNCSessionPubKey
EnableSSH bool
}

View File

@@ -167,6 +167,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
AuthorizedUsers: connRes.authorizedUsers,
VNCAuthorizedUsers: connRes.vncAuthorizedUsers,
VNCSessionPubKeys: connRes.vncSessionPubKeys,
EnableSSH: connRes.sshEnabled,
}
}
@@ -177,6 +178,7 @@ type peerConnectionResult struct {
firewallRules []*FirewallRule
authorizedUsers map[string]map[string]struct{}
vncAuthorizedUsers map[string]map[string]struct{}
vncSessionPubKeys []VNCSessionPubKey
sshEnabled bool
}
@@ -210,6 +212,7 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) p
firewallRules: fwRules,
authorizedUsers: state.authorizedUsers,
vncAuthorizedUsers: state.vncAuthorizedUsers,
vncSessionPubKeys: state.vncSessionPubKeys,
sshEnabled: state.sshEnabled,
}
}

View File

@@ -15,9 +15,21 @@ import (
type peerConnResolveState struct {
authorizedUsers map[string]map[string]struct{}
vncAuthorizedUsers map[string]map[string]struct{}
vncSessionPubKeys []VNCSessionPubKey
sshEnabled bool
}
// VNCSessionPubKey carries an ephemeral X25519 static public key the
// dashboard registered via temporary-access. The daemon uses it as the
// allowed-client side of a Noise_IK handshake; a successful handshake
// authenticates the connection as UserID.
type VNCSessionPubKey struct {
// PubKey is the base64-encoded 32-byte X25519 public key.
PubKey string
// UserID is the unhashed user identity the pubkey authenticates as.
UserID string
}
// ruleAuthCallbacks lets Account and NetworkMapComponents share the per-rule
// direction-and-auth logic while keeping their own context/state plumbing for
// authorized-user collection and allowed-user lookups.
@@ -57,6 +69,12 @@ func applyResolvedRuleToState(
return
}
cb.collectVNCUsers(rule, state.vncAuthorizedUsers)
if rule.SessionPubKey != "" && rule.AuthorizedUser != "" {
state.vncSessionPubKeys = append(state.vncSessionPubKeys, VNCSessionPubKey{
PubKey: rule.SessionPubKey,
UserID: rule.AuthorizedUser,
})
}
case policyRuleImpliesLegacySSH(rule) && targetPeerSSHEnabled:
if !peerInDestinations {
return

View File

@@ -88,6 +88,12 @@ type PolicyRule struct {
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
AuthorizedUser string
// SessionPubKey is the base64 Ed25519 public key the AuthorizedUser
// will sign session-binding challenges with. Set together with
// AuthorizedUser when the rule was created via temporary-access for
// a VNC scope; empty otherwise.
SessionPubKey string
}
// Copy returns a copy of a policy rule
@@ -109,6 +115,7 @@ func (pm *PolicyRule) Copy() *PolicyRule {
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
AuthorizedUser: pm.AuthorizedUser,
SessionPubKey: pm.SessionPubKey,
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
@@ -136,7 +143,8 @@ func (pm *PolicyRule) Equal(other *PolicyRule) bool {
pm.Protocol != other.Protocol ||
pm.SourceResource != other.SourceResource ||
pm.DestinationResource != other.DestinationResource ||
pm.AuthorizedUser != other.AuthorizedUser {
pm.AuthorizedUser != other.AuthorizedUser ||
pm.SessionPubKey != other.SessionPubKey {
return false
}

View File

@@ -1,68 +0,0 @@
package jwt
import (
"errors"
"fmt"
"time"
gojwt "github.com/golang-jwt/jwt/v5"
)
// ErrTokenExpired signals that the iat-based token age check failed. Callers
// use errors.Is to branch on it when they want to surface a stable machine-
// readable reason (e.g. so a dashboard can prompt for re-login).
var ErrTokenExpired = errors.New("token expired")
// CheckTokenAge validates that a JWT token's iat claim is within the given
// maxAge duration. Returns an error if the claims are unparsable, the iat
// claim is missing, or the token is too old.
func CheckTokenAge(token *gojwt.Token, maxAge time.Duration) error {
if token == nil {
return fmt.Errorf("token is nil")
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return fmt.Errorf("token has invalid claims format (user=%s)", UserIDFromToken(token))
}
iat, ok := claims["iat"].(float64)
if !ok {
return fmt.Errorf("token missing iat claim (user=%s)", UserIDFromToken(token))
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
if tokenAge > maxAge {
return fmt.Errorf("%w for user=%s: age=%v, max=%v", ErrTokenExpired, userIDFromClaims(claims), tokenAge, maxAge)
}
return nil
}
// UserIDFromToken extracts a human-readable user identifier from a JWT token
// for use in error messages. Returns "unknown" if the token or claims are nil.
func UserIDFromToken(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return userIDFromClaims(claims)
}
// userIDFromClaims extracts a user identifier from JWT claims, trying sub,
// user_id, and email in order.
func userIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}

View File

@@ -1007,6 +1007,10 @@ components:
items:
type: string
example: "tcp/80"
session_pub_key:
description: Ephemeral Ed25519 public key the requester will sign session-binding challenges with. Required for VNC rules; ignored for SSH and L4.
type: string
example: "n0r3pL4c3h0ld3rK3y=="
required:
- name
- wg_pub_key
@@ -1028,10 +1032,15 @@ components:
items:
type: string
example: "tcp/80"
target_pub_key:
description: Identity public key of the destination peer the temporary access was requested for. Used by the requester to verify the destination daemon's identity before transmitting credentials.
type: string
example: "n0r3pL4c3h0ld3rK3y=="
required:
- name
- id
- rules
- target_pub_key
AccessiblePeer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'

View File

@@ -3391,6 +3391,9 @@ type PeerTemporaryAccessRequest struct {
// Rules List of temporary access rules
Rules []string `json:"rules"`
// SessionPubKey Ephemeral Ed25519 public key the requester will sign session-binding challenges with. Required for VNC rules; ignored for SSH and L4.
SessionPubKey *string `json:"session_pub_key,omitempty"`
// WgPubKey Peer's WireGuard public key
WgPubKey string `json:"wg_pub_key"`
}
@@ -3405,6 +3408,9 @@ type PeerTemporaryAccessResponse struct {
// Rules List of temporary access rules
Rules []string `json:"rules"`
// TargetPubKey Identity public key of the destination peer the temporary access was requested for. Used by the requester to verify the destination daemon's identity before transmitting credentials.
TargetPubKey string `json:"target_pub_key"`
}
// PersonalAccessToken defines model for PersonalAccessToken.

File diff suppressed because it is too large Load Diff

View File

@@ -428,9 +428,6 @@ message MachineUserIndexes {
// VNCAuth represents VNC authorization configuration for a peer.
message VNCAuth {
// UserIDClaim is the JWT claim to be used to get the users ID
string UserIDClaim = 1;
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via VNC
repeated bytes AuthorizedUsers = 2;
@@ -438,6 +435,24 @@ message VNCAuth {
// Used in session mode to determine which OS user to create the virtual session as.
// The wildcard "*" allows any OS user.
map<string, MachineUserIndexes> machine_users = 3;
// SessionPubKeys are short-lived X25519 static keypairs the dashboard
// (or other temporary-access clients) registers per session. The
// daemon runs a Noise_IK handshake against the matching pubkey to
// authenticate the connection and resolve the pubkey back to a user.
repeated SessionPubKey session_pub_keys = 4;
}
// SessionPubKey binds an ephemeral X25519 static public key to a hashed
// user identity so the daemon can authorize VNC connections that
// complete a Noise_IK handshake with the matching private key.
message SessionPubKey {
// PubKey is the 32-byte X25519 static public key.
bytes pub_key = 1;
// UserIDHash is the BLAKE2b-128 hash of the user ID this session
// belongs to, matching the entries in VNCAuth.AuthorizedUsers.
bytes user_id_hash = 2;
}
// RemotePeerConfig represents a configuration of a remote peer.