mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-31 04:59:54 +00:00
Replace VNC JWT auth with a Noise_IK handshake bound to ACL-pushed pubkeys
This commit is contained in:
@@ -39,12 +39,22 @@ var vncAgentCmd = &cobra.Command{
|
||||
if token == "" {
|
||||
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
|
||||
}
|
||||
// Drop the token from our process environment so any child the
|
||||
// agent spawns does not inherit it, and casual debugging tools
|
||||
// that dump /proc/<pid>/environ (or the Windows equivalent) on a
|
||||
// running agent don't surface the loopback shared secret.
|
||||
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
|
||||
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
|
||||
}
|
||||
|
||||
capturer, injector, err := newAgentResources()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv := vncserver.New(capturer, injector)
|
||||
// The per-user agent listens only on loopback and is gated by an
|
||||
// agent token shared with the daemon, so no X25519 identity key
|
||||
// is needed; auth is disabled at the RFB layer.
|
||||
srv := vncserver.New(capturer, injector, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken(token)
|
||||
|
||||
|
||||
@@ -1064,7 +1064,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
|
||||
if err := e.updateVNC(); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -66,8 +66,7 @@ func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
// sshConf provides the JWT identity provider config (shared with SSH).
|
||||
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
|
||||
func (e *Engine) updateVNC() error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
@@ -81,15 +80,13 @@ func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
// Update JWT config on existing server in case management sent new config.
|
||||
e.updateVNCServerJWT(sshConf)
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer(sshConf)
|
||||
return e.startVNCServer()
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||
func (e *Engine) startVNCServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
@@ -102,7 +99,7 @@ func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
srv := vncserver.New(capturer, injector)
|
||||
srv := vncserver.New(capturer, injector, e.config.WgPrivateKey[:])
|
||||
if e.clientMetrics != nil {
|
||||
srv.SetSessionRecorder(func(t vncserver.SessionTick) {
|
||||
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
|
||||
@@ -122,20 +119,6 @@ func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||
srv.SetServiceMode(true)
|
||||
}
|
||||
|
||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
srv.SetJWTConfig(&vncserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
})
|
||||
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
srv.SetNetstackNet(netstackNet)
|
||||
}
|
||||
@@ -165,35 +148,6 @@ func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNCServerJWT configures the JWT validation for the VNC server using
|
||||
// the same JWT config as SSH (same identity provider).
|
||||
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
|
||||
if e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
protoJWT := sshConf.GetJwtConfig()
|
||||
if protoJWT == nil {
|
||||
return
|
||||
}
|
||||
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
|
||||
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
})
|
||||
}
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
@@ -221,10 +175,28 @@ func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
|
||||
for _, e := range vncAuth.GetSessionPubKeys() {
|
||||
pub := e.GetPubKey()
|
||||
if len(pub) != 32 {
|
||||
log.Warnf("VNC session pubkey wrong length %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash := e.GetUserIdHash()
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("VNC session user id hash wrong length %d", len(hash))
|
||||
continue
|
||||
}
|
||||
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIDHash: sshuserhash.UserIDHash(hash),
|
||||
})
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
UserIDClaim: vncAuth.GetUserIDClaim(),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: sessionPubKeys,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
type vncServer interface{}
|
||||
|
||||
func (e *Engine) updateVNC(_ *mgmProto.SSHConfig) error { return nil }
|
||||
func (e *Engine) updateVNC() error { return nil }
|
||||
|
||||
func (e *Engine) updateVNCServerAuth(_ *mgmProto.VNCAuth) {
|
||||
// no-op on platforms without a VNC server
|
||||
|
||||
@@ -2107,7 +2107,9 @@ type VNCSessionInfo struct {
|
||||
RemoteAddress string `protobuf:"bytes,1,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
|
||||
Mode string `protobuf:"bytes,2,opt,name=mode,proto3" json:"mode,omitempty"`
|
||||
Username string `protobuf:"bytes,3,opt,name=username,proto3" json:"username,omitempty"`
|
||||
JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
|
||||
// userID is the Noise-verified session identity (hashed user ID from
|
||||
// the ACL session-key entry), empty when auth is disabled.
|
||||
UserID string `protobuf:"bytes,4,opt,name=userID,proto3" json:"userID,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -2163,9 +2165,9 @@ func (x *VNCSessionInfo) GetUsername() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *VNCSessionInfo) GetJwtUsername() string {
|
||||
func (x *VNCSessionInfo) GetUserID() string {
|
||||
if x != nil {
|
||||
return x.JwtUsername
|
||||
return x.UserID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -6582,12 +6584,12 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" +
|
||||
"\x0eSSHServerState\x12\x18\n" +
|
||||
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
|
||||
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\x88\x01\n" +
|
||||
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"~\n" +
|
||||
"\x0eVNCSessionInfo\x12$\n" +
|
||||
"\rremoteAddress\x18\x01 \x01(\tR\rremoteAddress\x12\x12\n" +
|
||||
"\x04mode\x18\x02 \x01(\tR\x04mode\x12\x1a\n" +
|
||||
"\busername\x18\x03 \x01(\tR\busername\x12 \n" +
|
||||
"\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" +
|
||||
"\busername\x18\x03 \x01(\tR\busername\x12\x16\n" +
|
||||
"\x06userID\x18\x04 \x01(\tR\x06userID\"^\n" +
|
||||
"\x0eVNCServerState\x12\x18\n" +
|
||||
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
|
||||
"\bsessions\x18\x02 \x03(\v2\x16.daemon.VNCSessionInfoR\bsessions\"\xef\x04\n" +
|
||||
|
||||
@@ -403,7 +403,9 @@ message VNCSessionInfo {
|
||||
string remoteAddress = 1;
|
||||
string mode = 2;
|
||||
string username = 3;
|
||||
string jwtUsername = 4;
|
||||
// userID is the Noise-verified session identity (hashed user ID from
|
||||
// the ACL session-key entry), empty when auth is disabled.
|
||||
string userID = 4;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
|
||||
@@ -1199,7 +1199,7 @@ func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
RemoteAddress: sess.RemoteAddress,
|
||||
Mode: sess.Mode,
|
||||
Username: sess.Username,
|
||||
JwtUsername: sess.JWTUsername,
|
||||
UserID: sess.UserID,
|
||||
})
|
||||
}
|
||||
return &proto.VNCServerState{
|
||||
|
||||
@@ -15,13 +15,16 @@ const (
|
||||
DefaultUserIDClaim = "sub"
|
||||
// Wildcard is a special user ID that matches all users
|
||||
Wildcard = "*"
|
||||
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
|
||||
sessionPubKeyLen = 32
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
|
||||
)
|
||||
|
||||
// Authorizer handles SSH fine-grained access control authorization
|
||||
@@ -35,6 +38,12 @@ type Authorizer struct {
|
||||
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||
machineUsers map[string][]uint32
|
||||
|
||||
// sessionPubKeys maps an X25519 static public key (as map-safe
|
||||
// array) to the hashed user identity that key authenticates as.
|
||||
// Populated from management's temporary-access flow; used by VNC to
|
||||
// authenticate via the Noise_IK handshake.
|
||||
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
|
||||
|
||||
// mu protects the list of users
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -50,13 +59,25 @@ type Config struct {
|
||||
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||
MachineUsers map[string][]uint32
|
||||
|
||||
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
|
||||
// user identities. Populated for VNC; ignored on the SSH side.
|
||||
SessionPubKeys []SessionPubKey
|
||||
}
|
||||
|
||||
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
|
||||
// static public key plus the hashed user identity it authenticates as.
|
||||
type SessionPubKey struct {
|
||||
PubKey []byte
|
||||
UserIDHash sshuserhash.UserIDHash
|
||||
}
|
||||
|
||||
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||
func NewAuthorizer() *Authorizer {
|
||||
a := &Authorizer{
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
|
||||
}
|
||||
|
||||
return a
|
||||
@@ -72,6 +93,7 @@ func (a *Authorizer) Update(config *Config) {
|
||||
a.userIDClaim = DefaultUserIDClaim
|
||||
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||
a.machineUsers = make(map[string][]uint32)
|
||||
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
|
||||
log.Info("SSH authorization cleared")
|
||||
return
|
||||
}
|
||||
@@ -94,8 +116,19 @@ func (a *Authorizer) Update(config *Config) {
|
||||
}
|
||||
a.machineUsers = machineUsers
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||
len(config.AuthorizedUsers), len(machineUsers))
|
||||
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
|
||||
for _, e := range config.SessionPubKeys {
|
||||
if len(e.PubKey) != sessionPubKeyLen {
|
||||
continue
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], e.PubKey)
|
||||
sessionPubKeys[key] = e.UserIDHash
|
||||
}
|
||||
a.sessionPubKeys = sessionPubKeys
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
|
||||
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
|
||||
}
|
||||
|
||||
// Authorize validates if a user is authorized to login as the specified OS user.
|
||||
@@ -155,6 +188,38 @@ func (a *Authorizer) GetUserIDClaim() string {
|
||||
return a.userIDClaim
|
||||
}
|
||||
|
||||
// LookupSessionKey resolves a Noise-verified static public key to the
|
||||
// hashed user identity registered with it. Fails closed when the key is
|
||||
// unknown.
|
||||
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
|
||||
var zero sshuserhash.UserIDHash
|
||||
if len(pubKey) != sessionPubKeyLen {
|
||||
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], pubKey)
|
||||
a.mu.RLock()
|
||||
hash, ok := a.sessionPubKeys[key]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return zero, ErrSessionKeyNotKnown
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
|
||||
// key. Mirrors Authorize but skips the JWT-hash step since the key has
|
||||
// already been verified and the user identity hash is in hand.
|
||||
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
userIndex, found := a.findUserIndex(userIDHash)
|
||||
if !found {
|
||||
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
|
||||
}
|
||||
return a.checkMachineUserMapping("session", osUsername, userIndex)
|
||||
}
|
||||
|
||||
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||
// Returns the index and true if found, 0 and false if not found
|
||||
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -610,3 +611,61 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
|
||||
pub := bytesRepeat(0x11, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{
|
||||
AuthorizedUsers: []sshauth.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{Wildcard: {0}},
|
||||
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
|
||||
})
|
||||
|
||||
got, err := a.LookupSessionKey(pub)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userHash, got)
|
||||
|
||||
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
|
||||
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{})
|
||||
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
|
||||
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
_, err := a.LookupSessionKey([]byte("short"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
|
||||
pub := bytesRepeat(0x33, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
|
||||
if _, err := a.LookupSessionKey(pub); err != nil {
|
||||
t.Fatalf("setup lookup: %v", err)
|
||||
}
|
||||
a.Update(&Config{})
|
||||
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
|
||||
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func bytesRepeat(b byte, n int) []byte {
|
||||
out := make([]byte, n)
|
||||
for i := range out {
|
||||
out[i] = b
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
|
||||
}
|
||||
}
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication.
|
||||
// This is the same approach OpenSSH for Windows uses for public key authentication.
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
|
||||
@@ -551,7 +551,27 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
maxTokenAge = DefaultJWTMaxTokenAge
|
||||
}
|
||||
|
||||
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||
@@ -582,7 +602,27 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
return jwt.UserIDFromToken(token)
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
|
||||
@@ -135,7 +135,7 @@ type VNCSessionOutput struct {
|
||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
||||
Mode string `json:"mode" yaml:"mode"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
|
||||
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
|
||||
}
|
||||
|
||||
type VNCServerStateOutput struct {
|
||||
@@ -296,7 +296,7 @@ func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
|
||||
RemoteAddress: sess.GetRemoteAddress(),
|
||||
Mode: sess.GetMode(),
|
||||
Username: sess.GetUsername(),
|
||||
JWTUsername: sess.GetJwtUsername(),
|
||||
UserID: sess.GetUserID(),
|
||||
})
|
||||
}
|
||||
return VNCServerStateOutput{
|
||||
@@ -583,9 +583,9 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
if showSSHSessions && vncSessionCount > 0 {
|
||||
for _, sess := range o.VNCServerState.Sessions {
|
||||
var line string
|
||||
if sess.JWTUsername != "" {
|
||||
if sess.UserID != "" {
|
||||
line = fmt.Sprintf("[%s@%s -> %s] mode=%s",
|
||||
sess.JWTUsername, sess.RemoteAddress, sess.Username, sess.Mode)
|
||||
sess.UserID, sess.RemoteAddress, sess.Username, sess.Mode)
|
||||
} else {
|
||||
line = fmt.Sprintf("[%s] mode=%s user=%s",
|
||||
sess.RemoteAddress, sess.Mode, sess.Username)
|
||||
|
||||
431
client/vnc/server/noise_auth_test.go
Normal file
431
client/vnc/server/noise_auth_test.go
Normal file
@@ -0,0 +1,431 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
// noiseTestServer starts a VNC server with a freshly generated identity
|
||||
// key and returns the listener address, the server, and the server's
|
||||
// static public key for client-side handshake setup.
|
||||
func noiseTestServer(t *testing.T) (net.Addr, *Server, []byte) {
|
||||
t.Helper()
|
||||
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, kp.Private)
|
||||
srv.SetDisableAuth(false)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
return srv.listener.Addr(), srv, kp.Public
|
||||
}
|
||||
|
||||
// registerSessionKey enrolls a fresh X25519 keypair under the given user
|
||||
// ID into the server's authorizer with the requested OS-user wildcard
|
||||
// mapping. Returns the keypair so the test can drive the handshake.
|
||||
func registerSessionKey(t *testing.T, srv *Server, userID string) noise.DHKey {
|
||||
t.Helper()
|
||||
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userHash, err := sshuserhash.HashUserID(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{sshauth.Wildcard: {0}},
|
||||
SessionPubKeys: []sshauth.SessionPubKey{
|
||||
{PubKey: kp.Public, UserIDHash: userHash},
|
||||
},
|
||||
})
|
||||
return kp
|
||||
}
|
||||
|
||||
// writeHeaderPrefix writes the mode + zero-length-username prefix that
|
||||
// precedes the optional Noise handshake in the NetBird VNC header.
|
||||
func writeHeaderPrefix(t *testing.T, conn net.Conn, mode byte) {
|
||||
t.Helper()
|
||||
prefix := []byte{mode, 0, 0}
|
||||
_, err := conn.Write(prefix)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// writeHeaderTail writes the sessionID/width/height fields that follow
|
||||
// either the Noise msg2 (auth path) or the prefix alone (no-auth path).
|
||||
func writeHeaderTail(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
tail := make([]byte, 8)
|
||||
_, err := conn.Write(tail)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// performInitiator drives the initiator side of Noise_IK against the
|
||||
// server's identity public key, returns the resulting state. The Noise
|
||||
// msg2 produced by the server is read and consumed.
|
||||
func performInitiator(t *testing.T, conn net.Conn, clientKey noise.DHKey, serverPub []byte) {
|
||||
t.Helper()
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: serverPub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, noiseInitiatorMsgLen, len(msg1))
|
||||
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
msg2 := make([]byte, noiseResponderMsgLen)
|
||||
_, err = io.ReadFull(conn, msg2)
|
||||
require.NoError(t, err)
|
||||
_, _, _, err = state.ReadMessage(nil, msg2)
|
||||
require.NoError(t, err, "server responder message must decrypt with the correct peer static")
|
||||
}
|
||||
|
||||
// readRFBFailure consumes the RFB version exchange and returns the
|
||||
// security-failure reason string. Fails the test if the server did not
|
||||
// send a failure (i.e. produced a non-zero security-types list).
|
||||
func readRFBFailure(t *testing.T, conn net.Conn) string {
|
||||
t.Helper()
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
var ver [12]byte
|
||||
_, err := io.ReadFull(conn, ver[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "RFB 003.008\n", string(ver[:]))
|
||||
|
||||
_, err = conn.Write(ver[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var n [1]byte
|
||||
_, err = io.ReadFull(conn, n[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, byte(0), n[0], "expected security-failure (0 types)")
|
||||
|
||||
var rl [4]byte
|
||||
_, err = io.ReadFull(conn, rl[:])
|
||||
require.NoError(t, err)
|
||||
reason := make([]byte, binary.BigEndian.Uint32(rl[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
return string(reason)
|
||||
}
|
||||
|
||||
// readRFBGreetingNoFailure asserts the server proceeded past auth: it
|
||||
// must offer at least one security type rather than a 0 failure.
|
||||
func readRFBGreetingNoFailure(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
var ver [12]byte
|
||||
_, err := io.ReadFull(conn, ver[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "RFB 003.008\n", string(ver[:]))
|
||||
|
||||
_, err = conn.Write(ver[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var n [1]byte
|
||||
_, err = io.ReadFull(conn, n[:])
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, byte(0), n[0], "server must offer security types after a valid handshake")
|
||||
}
|
||||
|
||||
// TestNoise_RegisteredKey_AccessGranted exercises the happy path: a
|
||||
// session key enrolled in the authorizer completes a Noise_IK handshake
|
||||
// and the server proceeds to the RFB greeting.
|
||||
func TestNoise_RegisteredKey_AccessGranted(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
performInitiator(t, conn, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
readRFBGreetingNoFailure(t, conn)
|
||||
}
|
||||
|
||||
// TestNoise_UnregisteredClientStatic_Rejected proves the authorizer is
|
||||
// consulted: a syntactically-valid handshake from a key the server has
|
||||
// never been told about must be rejected fail-closed.
|
||||
func TestNoise_UnregisteredClientStatic_Rejected(t *testing.T) {
|
||||
addr, _, serverPub := noiseTestServer(t)
|
||||
// Auth is enabled but the authorizer was not updated, so the lookup
|
||||
// path returns ErrSessionKeyNotKnown.
|
||||
attackerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
performInitiator(t, conn, attackerKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "session pubkey not registered")
|
||||
}
|
||||
|
||||
// TestNoise_WrongServerStatic_HandshakeFails proves the server's
|
||||
// identity is bound into the handshake: an initiator using the wrong
|
||||
// peer static encrypts msg1 under keys the real server can't derive, so
|
||||
// the server fails the handshake and closes without RFB output.
|
||||
func TestNoise_WrongServerStatic_HandshakeFails(t *testing.T) {
|
||||
addr, srv, _ := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
bogusServerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: bogusServerKey.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "server must close without RFB greeting when msg1 is sealed for a different server identity")
|
||||
}
|
||||
|
||||
// TestNoise_MalformedMsg1_ClosesConnection covers the case where the
|
||||
// magic prefix is correct but the following 96 bytes are random: the
|
||||
// noise library fails ReadMessage and the server closes silently.
|
||||
func TestNoise_MalformedMsg1_ClosesConnection(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
junk := make([]byte, noiseInitiatorMsgLen)
|
||||
for i := range junk {
|
||||
junk[i] = byte(i)
|
||||
}
|
||||
_, err = conn.Write(append([]byte("NBV3"), junk...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "garbage msg1 must terminate the connection before any RFB output")
|
||||
}
|
||||
|
||||
// TestNoise_TruncatedMsg1_ClosesConnection sends fewer than the 96
|
||||
// bytes a Noise_IK msg1 must contain. The server's io.ReadFull short-
|
||||
// reads and closes; no RFB greeting must leak.
|
||||
func TestNoise_TruncatedMsg1_ClosesConnection(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
_, err = conn.Write([]byte("NBV3"))
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(make([]byte, 8))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
|
||||
// Re-dial just to confirm the listener is alive (the previous
|
||||
// connection terminated server-side without affecting the listener).
|
||||
probe, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, probe.Close())
|
||||
}
|
||||
|
||||
// TestNoise_AuthEnabled_NoHandshake_Rejected proves that with auth on,
|
||||
// a connection that skips the Noise prefix (older client / VNC client)
|
||||
// is rejected with AUTH_FORBIDDEN: identity proof missing.
|
||||
func TestNoise_AuthEnabled_NoHandshake_Rejected(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "identity proof missing")
|
||||
}
|
||||
|
||||
// TestNoise_RevokedKey_RejectedAfterAuthUpdate verifies the authorizer
|
||||
// honors revocations: a key that worked before a UpdateVNCAuth call
|
||||
// must stop working as soon as the new config omits it.
|
||||
func TestNoise_RevokedKey_RejectedAfterAuthUpdate(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
// First connection succeeds.
|
||||
conn1, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
writeHeaderPrefix(t, conn1, ModeAttach)
|
||||
performInitiator(t, conn1, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn1)
|
||||
readRFBGreetingNoFailure(t, conn1)
|
||||
|
||||
// Revoke by pushing a fresh config that drops the pubkey entry.
|
||||
srv.UpdateVNCAuth(&sshauth.Config{})
|
||||
|
||||
// Same client, same Noise key, should now be denied.
|
||||
conn2, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
writeHeaderPrefix(t, conn2, ModeAttach)
|
||||
performInitiator(t, conn2, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn2)
|
||||
|
||||
reason := readRFBFailure(t, conn2)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "session pubkey not registered")
|
||||
}
|
||||
|
||||
// TestNoise_NoIdentityKey_FailsClosed ensures a server constructed
|
||||
// without a static private key still rejects authenticated connections
|
||||
// fail-closed; it must not silently accept the client.
|
||||
func TestNoise_NoIdentityKey_FailsClosed(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(false)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
clientKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
fakeServerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: fakeServerKey.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "server without identity key must not write the RFB greeting")
|
||||
}
|
||||
|
||||
// TestNoise_DerivedIdentityPublicMatchesPrivate sanity-checks the
|
||||
// derivation done in New(): the identityPublic must be Curve25519.
|
||||
// Basepoint multiplied with identityKey.
|
||||
func TestNoise_DerivedIdentityPublicMatchesPrivate(t *testing.T) {
|
||||
priv := make([]byte, 32)
|
||||
for i := range priv {
|
||||
priv[i] = byte(i + 1)
|
||||
}
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, priv)
|
||||
|
||||
expected, err := curve25519.X25519(priv, curve25519.Basepoint)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, srv.identityPublic)
|
||||
}
|
||||
|
||||
// TestNoise_SessionMode_OSUserCheckRunsAfterHandshake verifies that a
|
||||
// successful Noise handshake doesn't bypass OS-user authorization: an
|
||||
// authenticated key whose user index isn't mapped to the requested OS
|
||||
// user must be rejected.
|
||||
func TestNoise_SessionMode_OSUserCheckRunsAfterHandshake(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
|
||||
clientKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
userHash, err := sshuserhash.HashUserID("alice@example")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Map Alice only to "alice" OS user, not the wildcard.
|
||||
srv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{"alice": {0}},
|
||||
SessionPubKeys: []sshauth.SessionPubKey{
|
||||
{PubKey: clientKey.Public, UserIDHash: userHash},
|
||||
},
|
||||
})
|
||||
|
||||
// Request session for "bob" — Noise succeeds, OS-user check denies.
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
bob := []byte("bob")
|
||||
prefix := []byte{ModeSession, 0, byte(len(bob))}
|
||||
prefix = append(prefix, bob...)
|
||||
_, err = conn.Write(prefix)
|
||||
require.NoError(t, err)
|
||||
|
||||
performInitiator(t, conn, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "authorize OS user")
|
||||
}
|
||||
@@ -3,6 +3,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
@@ -13,16 +15,15 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
"github.com/flynn/noise"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
// Connection modes sent by the client in the session header.
|
||||
@@ -35,11 +36,7 @@ const (
|
||||
// stable so clients can branch on them without parsing free text.
|
||||
// Format: "CODE: human message".
|
||||
const (
|
||||
RejectCodeJWTMissing = "AUTH_JWT_MISSING"
|
||||
RejectCodeJWTExpired = "AUTH_JWT_EXPIRED"
|
||||
RejectCodeJWTInvalid = "AUTH_JWT_INVALID"
|
||||
RejectCodeAuthForbidden = "AUTH_FORBIDDEN"
|
||||
RejectCodeAuthConfig = "AUTH_CONFIG"
|
||||
RejectCodeSessionError = "SESSION_ERROR"
|
||||
RejectCodeCapturerError = "CAPTURER_ERROR"
|
||||
RejectCodeUnsupportedOS = "UNSUPPORTED"
|
||||
@@ -56,6 +53,21 @@ const EnvVNCDisableDownscale = "NB_VNC_DISABLE_DOWNSCALE"
|
||||
// enough to coalesce bursty multi-session requests. 16 ms ~= 60 fps.
|
||||
const freshWindow = 16 * time.Millisecond
|
||||
|
||||
// maxConcurrentVNCConns caps in-flight VNC connections. Each accepted
|
||||
// connection consumes a handler goroutine, a tracking entry, and (after
|
||||
// handshake) capturer/encoder resources, so an unauthenticated peer that
|
||||
// dials in a tight loop could otherwise grow memory without bound. The
|
||||
// limit covers the entire accept→handshake→session window; a slot is
|
||||
// released only when the handler returns.
|
||||
const maxConcurrentVNCConns = 64
|
||||
|
||||
// maxFramebufferDim caps the screen dimensions accepted from a capturer.
|
||||
// RFB serialises width/height as u16, and the encoder allocates per-frame
|
||||
// buffers proportional to width*height*4. 8192 keeps width*height*4 well
|
||||
// under 2^31 so int math doesn't overflow on 32-bit builds, and is large
|
||||
// enough to cover real-world multi-monitor desktops.
|
||||
const maxFramebufferDim = 8192
|
||||
|
||||
// ScreenCapturer grabs desktop frames for the VNC server.
|
||||
type ScreenCapturer interface {
|
||||
// Width returns the current screen width in pixels.
|
||||
@@ -120,26 +132,22 @@ type InputInjector interface {
|
||||
TypeText(text string)
|
||||
}
|
||||
|
||||
// JWTConfig holds JWT validation configuration for VNC auth.
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
Audiences []string
|
||||
}
|
||||
|
||||
// connectionHeader is sent by the client before the RFB handshake to specify
|
||||
// the VNC session mode and authenticate.
|
||||
type connectionHeader struct {
|
||||
mode byte
|
||||
username string
|
||||
jwt string
|
||||
// clientStatic is the client's static X25519 public key learned from
|
||||
// the Noise handshake. Populated when identityVerified is true.
|
||||
clientStatic []byte
|
||||
// sessionID is the Windows session ID; 0 selects the console session.
|
||||
sessionID uint32
|
||||
// width and height request the virtual display geometry for session mode.
|
||||
// Zero means use the default.
|
||||
width uint16
|
||||
height uint16
|
||||
// identityVerified is true when the Noise_IK handshake completed.
|
||||
identityVerified bool
|
||||
}
|
||||
|
||||
// Server is the embedded VNC server that listens on the WireGuard interface.
|
||||
@@ -170,13 +178,16 @@ type Server struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
vmgr virtualSessionManager
|
||||
jwtConfig *JWTConfig
|
||||
jwtValidator *nbjwt.Validator
|
||||
jwtExtractor *nbjwt.ClaimsExtractor
|
||||
authorizer *sshauth.Authorizer
|
||||
netstackNet *netstack.Net
|
||||
authorizer *sshauth.Authorizer
|
||||
netstackNet *netstack.Net
|
||||
// agentToken holds the raw token bytes for agent-mode auth.
|
||||
agentToken []byte
|
||||
// identityKey is the daemon's static X25519 private key used in the
|
||||
// Noise_IK handshake. Nil disables the handshake.
|
||||
identityKey []byte
|
||||
// identityPublic is the matching X25519 public key, derived once at
|
||||
// construction to avoid recomputing per handshake.
|
||||
identityPublic []byte
|
||||
|
||||
sessionsMu sync.Mutex
|
||||
sessionSeq uint64
|
||||
@@ -188,6 +199,17 @@ type Server struct {
|
||||
// closeActiveSessions iterates this set so Stop() can interrupt
|
||||
// handshaking peers, not just post-handshake sessions.
|
||||
acceptedConns map[net.Conn]struct{}
|
||||
// connAuth holds the verified Noise_IK identity tied to each accepted
|
||||
// connection so a later UpdateVNCAuth call can revoke live sessions
|
||||
// whose authorization no longer holds. Populated by registerConnAuth
|
||||
// once authenticateSession succeeds; absent entries (e.g. disableAuth
|
||||
// or pre-handshake conns) are skipped at revocation time.
|
||||
connAuth map[net.Conn]connAuthInfo
|
||||
|
||||
// connSem caps concurrent accepted connections (handshake + session).
|
||||
// Buffered with maxConcurrentVNCConns slots; accept loops try-acquire
|
||||
// before spawning a handler and release on handler return.
|
||||
connSem chan struct{}
|
||||
|
||||
// sessionRecorder, when non-nil, receives a SessionTick periodically
|
||||
// during each VNC session and on session close. The engine wires
|
||||
@@ -195,12 +217,24 @@ type Server struct {
|
||||
sessionRecorder func(SessionTick)
|
||||
}
|
||||
|
||||
// connAuthInfo captures the Noise_IK-verified identity bound to a live
|
||||
// connection so policy updates can re-check it and close sessions whose
|
||||
// authorization was revoked. clientStatic is empty when auth was disabled
|
||||
// for this connection, which signals that revocation does not apply.
|
||||
type connAuthInfo struct {
|
||||
clientStatic []byte
|
||||
mode byte
|
||||
username string
|
||||
}
|
||||
|
||||
// ActiveSessionInfo describes a currently connected VNC client.
|
||||
type ActiveSessionInfo struct {
|
||||
RemoteAddress string
|
||||
Mode string
|
||||
Username string
|
||||
JWTUsername string
|
||||
// UserID is the authenticated session identity (hashed user ID from
|
||||
// the Noise_IK static-key registration), empty when auth is disabled.
|
||||
UserID string
|
||||
}
|
||||
|
||||
// vncSession provides capturer and injector for a virtual display session.
|
||||
@@ -220,19 +254,31 @@ type virtualSessionManager interface {
|
||||
StopAll()
|
||||
}
|
||||
|
||||
// New creates a VNC server with the given screen capturer and input injector.
|
||||
// Authentication uses a JWT supplied by the client in the connection
|
||||
// header; the protocol-level VNC password scheme is not supported.
|
||||
func New(capturer ScreenCapturer, injector InputInjector) *Server {
|
||||
return &Server{
|
||||
// New creates a VNC server. identityKey is the 32-byte X25519 private
|
||||
// key used by the daemon in the Noise_IK handshake; nil disables auth.
|
||||
// The protocol-level VNC password scheme is not supported.
|
||||
func New(capturer ScreenCapturer, injector InputInjector, identityKey []byte) *Server {
|
||||
s := &Server{
|
||||
capturer: capturer,
|
||||
injector: injector,
|
||||
identityKey: identityKey,
|
||||
authorizer: sshauth.NewAuthorizer(),
|
||||
log: log.WithField("component", "vnc-server"),
|
||||
sessions: make(map[uint64]ActiveSessionInfo),
|
||||
sessionConns: make(map[uint64]net.Conn),
|
||||
acceptedConns: make(map[net.Conn]struct{}),
|
||||
connAuth: make(map[net.Conn]connAuthInfo),
|
||||
connSem: make(chan struct{}, maxConcurrentVNCConns),
|
||||
}
|
||||
if len(identityKey) == 32 {
|
||||
pub, err := curve25519.X25519(identityKey, curve25519.Basepoint)
|
||||
if err == nil {
|
||||
s.identityPublic = pub
|
||||
} else {
|
||||
s.log.Warnf("derive identity public key: %v", err)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ActiveSessions returns a snapshot of currently connected VNC clients.
|
||||
@@ -292,9 +338,75 @@ func (s *Server) trackConn(c net.Conn) {
|
||||
func (s *Server) untrackConn(c net.Conn) {
|
||||
s.sessionsMu.Lock()
|
||||
delete(s.acceptedConns, c)
|
||||
delete(s.connAuth, c)
|
||||
s.sessionsMu.Unlock()
|
||||
}
|
||||
|
||||
// registerConnAuth records the verified Noise_IK identity for a live
|
||||
// connection so UpdateVNCAuth can later revoke it if policy changes.
|
||||
// No-op when auth is disabled (e.g. agent-mode loopback connections).
|
||||
func (s *Server) registerConnAuth(c net.Conn, header *connectionHeader) {
|
||||
if s.disableAuth || header == nil || len(header.clientStatic) != 32 {
|
||||
return
|
||||
}
|
||||
s.sessionsMu.Lock()
|
||||
s.connAuth[c] = connAuthInfo{
|
||||
clientStatic: append([]byte(nil), header.clientStatic...),
|
||||
mode: header.mode,
|
||||
username: header.username,
|
||||
}
|
||||
s.sessionsMu.Unlock()
|
||||
}
|
||||
|
||||
// tryAcquireConnSlot returns true when a connection slot was successfully
|
||||
// reserved. Releases must pair with releaseConnSlot. Returns false when
|
||||
// the cap is already saturated; callers must close the connection.
|
||||
func (s *Server) tryAcquireConnSlot() bool {
|
||||
select {
|
||||
case s.connSem <- struct{}{}:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) releaseConnSlot() {
|
||||
select {
|
||||
case <-s.connSem:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// revokeUnauthorizedSessions closes every live connection whose Noise-
|
||||
// verified identity no longer authenticates under the current authorizer
|
||||
// configuration. Called by UpdateVNCAuth after the new policy is applied.
|
||||
func (s *Server) revokeUnauthorizedSessions() {
|
||||
if s.disableAuth {
|
||||
return
|
||||
}
|
||||
s.sessionsMu.Lock()
|
||||
victims := make([]net.Conn, 0)
|
||||
for c, info := range s.connAuth {
|
||||
if len(info.clientStatic) != 32 {
|
||||
continue
|
||||
}
|
||||
hdr := &connectionHeader{
|
||||
identityVerified: true,
|
||||
clientStatic: info.clientStatic,
|
||||
mode: info.mode,
|
||||
username: info.username,
|
||||
}
|
||||
if _, err := s.authenticateSession(hdr); err != nil {
|
||||
victims = append(victims, c)
|
||||
s.log.Infof("revoking VNC session from %s: %v", c.RemoteAddr(), err)
|
||||
}
|
||||
}
|
||||
s.sessionsMu.Unlock()
|
||||
for _, c := range victims {
|
||||
_ = c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
|
||||
func (s *Server) SetServiceMode(enabled bool) {
|
||||
s.serviceMode = enabled
|
||||
@@ -308,16 +420,6 @@ func (s *Server) SetSessionRecorder(recorder func(SessionTick)) {
|
||||
s.sessionRecorder = recorder
|
||||
}
|
||||
|
||||
// SetJWTConfig configures JWT authentication for VNC connections.
|
||||
// Pass nil to disable JWT (public mode).
|
||||
func (s *Server) SetJWTConfig(config *JWTConfig) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.jwtConfig = config
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
}
|
||||
|
||||
// SetDisableAuth disables authentication entirely.
|
||||
func (s *Server) SetDisableAuth(disable bool) {
|
||||
s.disableAuth = disable
|
||||
@@ -346,13 +448,14 @@ func (s *Server) SetNetstackNet(n *netstack.Net) {
|
||||
s.netstackNet = n
|
||||
}
|
||||
|
||||
// UpdateVNCAuth updates the fine-grained authorization configuration.
|
||||
// UpdateVNCAuth updates the fine-grained authorization configuration and
|
||||
// closes any live session whose identity no longer authenticates under
|
||||
// the new policy. Revocation is event-driven: there is no periodic
|
||||
// re-check, so a session stays open until either the next UpdateVNCAuth
|
||||
// call or normal disconnect.
|
||||
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
s.authorizer.Update(config)
|
||||
s.revokeUnauthorizedSessions()
|
||||
}
|
||||
|
||||
// Start begins listening for VNC connections on the given address.
|
||||
@@ -463,9 +566,15 @@ func (s *Server) acceptLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.tryAcquireConnSlot() {
|
||||
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
enableTCPKeepAlive(conn, s.log)
|
||||
s.trackConn(conn)
|
||||
go func(c net.Conn) {
|
||||
defer s.releaseConnSlot()
|
||||
defer s.untrackConn(c)
|
||||
s.handleConnection(c)
|
||||
}(conn)
|
||||
@@ -565,16 +674,17 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
if !s.verifyAgentToken(conn, connLog) {
|
||||
return
|
||||
}
|
||||
header, err := readConnectionHeader(conn)
|
||||
header, err := s.readConnectionHeader(conn)
|
||||
if err != nil {
|
||||
connLog.Warnf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
connLog, jwtUserID, ok := s.authorizeJWT(conn, header, connLog)
|
||||
connLog, sessionUserID, ok := s.authorizeSession(conn, header, connLog)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.registerConnAuth(conn, header)
|
||||
|
||||
capturer, injector, sessionCleanup, ok := s.acquireSessionResources(conn, header, &connLog)
|
||||
if !ok {
|
||||
@@ -586,7 +696,7 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
RemoteAddress: conn.RemoteAddr().String(),
|
||||
Mode: modeString(header.mode),
|
||||
Username: header.username,
|
||||
JWTUsername: jwtUserID,
|
||||
UserID: sessionUserID,
|
||||
}, conn)
|
||||
defer s.removeSession(sessionID)
|
||||
|
||||
@@ -596,13 +706,20 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
w, h := capturer.Width(), capturer.Height()
|
||||
if w <= 0 || h <= 0 || w > maxFramebufferDim || h > maxFramebufferDim {
|
||||
rejectConnection(conn, codeMessage(RejectCodeCapturerError, fmt.Sprintf("framebuffer dimensions out of range: %dx%d", w, h)))
|
||||
connLog.Warnf("rejecting session: framebuffer %dx%d outside [1, %d]", w, h, maxFramebufferDim)
|
||||
return
|
||||
}
|
||||
|
||||
conn = newMetricsConn(conn, s.sessionRecorder)
|
||||
sess := &session{
|
||||
conn: conn,
|
||||
capturer: capturer,
|
||||
injector: injector,
|
||||
serverW: capturer.Width(),
|
||||
serverH: capturer.Height(),
|
||||
serverW: w,
|
||||
serverH: h,
|
||||
log: connLog,
|
||||
}
|
||||
sess.serve()
|
||||
@@ -615,25 +732,6 @@ func codeMessage(code, msg string) string {
|
||||
return code + ": " + msg
|
||||
}
|
||||
|
||||
// jwtErrorCode maps a JWT auth error to a stable reject code.
|
||||
func jwtErrorCode(err error) string {
|
||||
if err == nil {
|
||||
return RejectCodeJWTInvalid
|
||||
}
|
||||
if errors.Is(err, nbjwt.ErrTokenExpired) {
|
||||
return RejectCodeJWTExpired
|
||||
}
|
||||
msg := err.Error()
|
||||
switch {
|
||||
case strings.Contains(msg, "JWT required but not provided"):
|
||||
return RejectCodeJWTMissing
|
||||
case strings.Contains(msg, "authorize") || strings.Contains(msg, "not authorized"):
|
||||
return RejectCodeAuthForbidden
|
||||
default:
|
||||
return RejectCodeJWTInvalid
|
||||
}
|
||||
}
|
||||
|
||||
// rejectConnection sends a minimal RFB handshake with a security failure
|
||||
// reason, so VNC clients display the error message instead of a generic
|
||||
// "unexpected disconnect."
|
||||
@@ -658,105 +756,57 @@ func rejectConnection(conn net.Conn, reason string) {
|
||||
_, _ = conn.Write(buf)
|
||||
}
|
||||
|
||||
const defaultJWTMaxTokenAge = 10 * 60 // 10 minutes
|
||||
|
||||
// authenticateJWT validates the JWT from the connection header and checks
|
||||
// authorization. For attach mode, just checks membership in the authorized
|
||||
// user list. For session mode, additionally validates the OS user mapping.
|
||||
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
|
||||
if header.jwt == "" {
|
||||
return "", fmt.Errorf("JWT required but not provided")
|
||||
// authenticateSession resolves the Noise-verified client static public
|
||||
// key to a hashed user identity via the authorizer, and checks OS-user
|
||||
// mapping for session mode. Returns the hashed user identity on success.
|
||||
func (s *Server) authenticateSession(header *connectionHeader) (string, error) {
|
||||
if !header.identityVerified {
|
||||
return "", fmt.Errorf("identity proof missing")
|
||||
}
|
||||
if len(header.clientStatic) != 32 {
|
||||
return "", fmt.Errorf("client static key missing")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("initialize JWT validator: %w", err)
|
||||
}
|
||||
validator := s.jwtValidator
|
||||
extractor := s.jwtExtractor
|
||||
s.mu.Unlock()
|
||||
|
||||
token, err := validator.ValidateAndParse(context.Background(), header.jwt)
|
||||
userIDHash, err := s.authorizer.LookupSessionKey(header.clientStatic)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("validate JWT: %w", err)
|
||||
return "", fmt.Errorf("lookup session pubkey: %w", err)
|
||||
}
|
||||
|
||||
if err := s.checkTokenAge(token); err != nil {
|
||||
return "", err
|
||||
osUser := "*"
|
||||
if header.mode == ModeSession {
|
||||
osUser = header.username
|
||||
}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("extract user from JWT: %w", err)
|
||||
if _, err := s.authorizer.AuthorizeOSUserBySessionKey(userIDHash, osUser); err != nil {
|
||||
return "", fmt.Errorf("authorize OS user %q: %w", osUser, err)
|
||||
}
|
||||
if userAuth.UserId == "" {
|
||||
return "", fmt.Errorf("JWT has no user ID")
|
||||
}
|
||||
|
||||
switch header.mode {
|
||||
case ModeSession:
|
||||
// Session mode: check user + OS username mapping.
|
||||
if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
|
||||
return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
|
||||
}
|
||||
default:
|
||||
// Attach mode: just check user is in the authorized list (wildcard OS user).
|
||||
if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
|
||||
return "", fmt.Errorf("user not authorized for VNC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return userAuth.UserId, nil
|
||||
return userIDHash.String(), nil
|
||||
}
|
||||
|
||||
// ensureJWTValidator lazily initializes the JWT validator. Must be called with mu held.
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
if s.jwtConfig == nil {
|
||||
return fmt.Errorf("no JWT config")
|
||||
}
|
||||
var vncIdentityMagic = []byte("NBV3")
|
||||
|
||||
// Enable IdP key refresh so JWKS rotations don't latch the validator
|
||||
// off until daemon restart.
|
||||
s.jwtValidator = nbjwt.NewValidator(
|
||||
s.jwtConfig.Issuer,
|
||||
s.jwtConfig.Audiences,
|
||||
s.jwtConfig.KeysLocation,
|
||||
true,
|
||||
)
|
||||
// Noise_IK_25519_ChaChaPoly_SHA256 message sizes (with empty payloads).
|
||||
// msg1 = e(32) + s_AEAD(32+16) + payload_AEAD(0+16) = 96 bytes
|
||||
// msg2 = e(32) + payload_AEAD(0+16) = 48 bytes
|
||||
const (
|
||||
noiseInitiatorMsgLen = 96
|
||||
noiseResponderMsgLen = 48
|
||||
)
|
||||
|
||||
var opts []nbjwt.ClaimsExtractorOption
|
||||
if len(s.jwtConfig.Audiences) > 0 {
|
||||
opts = append(opts, nbjwt.WithAudience(s.jwtConfig.Audiences[0]))
|
||||
}
|
||||
if claim := s.authorizer.GetUserIDClaim(); claim != "" {
|
||||
opts = append(opts, nbjwt.WithUserIDClaim(claim))
|
||||
}
|
||||
s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
|
||||
// vncNoiseSuite pins the cipher suite for the VNC handshake. Changing
|
||||
// it requires bumping vncIdentityMagic so old clients fail closed.
|
||||
var vncNoiseSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token) error {
|
||||
maxAge := defaultJWTMaxTokenAge
|
||||
if s.jwtConfig != nil && s.jwtConfig.MaxTokenAge > 0 {
|
||||
maxAge = int(s.jwtConfig.MaxTokenAge)
|
||||
}
|
||||
return nbjwt.CheckTokenAge(token, time.Duration(maxAge)*time.Second)
|
||||
}
|
||||
|
||||
// readConnectionHeader reads the NetBird VNC session header from the connection.
|
||||
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
|
||||
// readConnectionHeader reads the NetBird VNC session header. Format:
|
||||
//
|
||||
// [jwt_len: 2 bytes BE] [jwt: N bytes]
|
||||
// [mode: 1] [username_len: 2 BE] [username: N]
|
||||
// [opt magic "NBV3": 4] [noise_msg1: 96]
|
||||
// (server writes [noise_msg2: 48] here when the magic is present)
|
||||
// [session_id: 4 BE] [width: 2 BE] [height: 2 BE]
|
||||
//
|
||||
// Uses a short timeout: our WASM proxy sends the header immediately after
|
||||
// connecting. Standard VNC clients don't send anything first (server speaks
|
||||
// first in RFB), so they time out and get the default attach mode.
|
||||
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
|
||||
// Standard VNC clients don't speak first, so they time out on the first
|
||||
// read and fall through to attach mode (which auth still rejects when
|
||||
// no Noise handshake completed).
|
||||
func (s *Server) readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
@@ -764,11 +814,9 @@ func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
|
||||
|
||||
var hdr [3]byte
|
||||
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
|
||||
// Timeout or error: assume no header, use attach mode.
|
||||
return &connectionHeader{mode: ModeAttach}, nil
|
||||
}
|
||||
|
||||
// Restore a longer deadline for reading variable-length fields.
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
@@ -788,48 +836,93 @@ func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
|
||||
username = string(buf)
|
||||
}
|
||||
|
||||
// Read JWT token length and data.
|
||||
var jwtLenBuf [2]byte
|
||||
var jwtToken string
|
||||
if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
|
||||
jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
|
||||
if jwtLen >= 8192 {
|
||||
return nil, fmt.Errorf("jwt too long: %d (max 8191)", jwtLen)
|
||||
}
|
||||
if jwtLen > 0 {
|
||||
buf := make([]byte, jwtLen)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return nil, fmt.Errorf("read JWT: %w", err)
|
||||
}
|
||||
jwtToken = string(buf)
|
||||
}
|
||||
br := bufio.NewReader(conn)
|
||||
clientStatic, identityVerified, err := s.maybeRunNoiseHandshake(conn, br)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read optional Windows session ID (4 bytes BE). Missing = 0 (console/auto).
|
||||
var sessionID uint32
|
||||
var sidBuf [4]byte
|
||||
if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
|
||||
if _, err := io.ReadFull(br, sidBuf[:]); err == nil {
|
||||
sessionID = binary.BigEndian.Uint32(sidBuf[:])
|
||||
}
|
||||
|
||||
// Read optional requested viewport size (2x uint16 BE). Missing = 0 (default).
|
||||
var width, height uint16
|
||||
var geomBuf [4]byte
|
||||
if _, err := io.ReadFull(conn, geomBuf[:]); err == nil {
|
||||
if _, err := io.ReadFull(br, geomBuf[:]); err == nil {
|
||||
width = binary.BigEndian.Uint16(geomBuf[0:2])
|
||||
height = binary.BigEndian.Uint16(geomBuf[2:4])
|
||||
}
|
||||
|
||||
return &connectionHeader{
|
||||
mode: mode,
|
||||
username: username,
|
||||
jwt: jwtToken,
|
||||
sessionID: sessionID,
|
||||
width: width,
|
||||
height: height,
|
||||
mode: mode,
|
||||
username: username,
|
||||
clientStatic: clientStatic,
|
||||
sessionID: sessionID,
|
||||
width: width,
|
||||
height: height,
|
||||
identityVerified: identityVerified,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// maybeRunNoiseHandshake performs the responder side of a Noise_IK
|
||||
// handshake when the client sends the v3 magic. Returns the client static
|
||||
// public key learned from the handshake. Any handshake failure is fatal
|
||||
// (fail closed).
|
||||
func (s *Server) maybeRunNoiseHandshake(conn net.Conn, br *bufio.Reader) ([]byte, bool, error) {
|
||||
peek, err := br.Peek(len(vncIdentityMagic))
|
||||
if err != nil || !bytes.Equal(peek, vncIdentityMagic) {
|
||||
return nil, false, nil
|
||||
}
|
||||
if _, err := br.Discard(len(vncIdentityMagic)); err != nil {
|
||||
return nil, false, fmt.Errorf("discard identity magic: %w", err)
|
||||
}
|
||||
|
||||
msg1 := make([]byte, noiseInitiatorMsgLen)
|
||||
if _, err := io.ReadFull(br, msg1); err != nil {
|
||||
return nil, false, fmt.Errorf("read noise msg1: %w", err)
|
||||
}
|
||||
|
||||
// Agents on loopback authenticate via the agent token, not this
|
||||
// handshake. Consume the replayed bytes and skip the response.
|
||||
if s.disableAuth {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
if len(s.identityKey) != 32 || len(s.identityPublic) != 32 {
|
||||
return nil, false, errors.New("identity key not configured")
|
||||
}
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: false,
|
||||
StaticKeypair: noise.DHKey{Private: s.identityKey, Public: s.identityPublic},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("noise responder init: %w", err)
|
||||
}
|
||||
if _, _, _, err := state.ReadMessage(nil, msg1); err != nil {
|
||||
return nil, false, fmt.Errorf("noise read msg1: %w", err)
|
||||
}
|
||||
msg2, _, _, err := state.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("noise write msg2: %w", err)
|
||||
}
|
||||
if len(msg2) != noiseResponderMsgLen {
|
||||
return nil, false, fmt.Errorf("noise responder produced %d bytes, expected %d", len(msg2), noiseResponderMsgLen)
|
||||
}
|
||||
if _, err := conn.Write(msg2); err != nil {
|
||||
return nil, false, fmt.Errorf("write noise msg2: %w", err)
|
||||
}
|
||||
|
||||
clientStatic := state.PeerStatic()
|
||||
if len(clientStatic) != 32 {
|
||||
return nil, false, errors.New("noise peer static missing")
|
||||
}
|
||||
return clientStatic, true, nil
|
||||
}
|
||||
|
||||
// verifyAgentToken validates the agent token prefix when configured. Returns
|
||||
// false when the token is invalid or unreadable; the connection is closed.
|
||||
func (s *Server) verifyAgentToken(conn net.Conn, connLog *log.Entry) bool {
|
||||
@@ -865,25 +958,34 @@ func (s *Server) verifyAgentToken(conn net.Conn, connLog *log.Entry) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// authorizeJWT performs JWT validation when auth is enabled. Returns the
|
||||
// enriched log entry, jwt user ID (empty when auth disabled), and ok=false
|
||||
// if the connection was rejected.
|
||||
func (s *Server) authorizeJWT(conn net.Conn, header *connectionHeader, connLog *log.Entry) (*log.Entry, string, bool) {
|
||||
// authorizeSession runs the Noise_IK handshake when auth is enabled.
|
||||
// Returns the enriched log entry, user identity hash (empty when auth
|
||||
// disabled), and ok=false if the connection was rejected.
|
||||
func (s *Server) authorizeSession(conn net.Conn, header *connectionHeader, connLog *log.Entry) (*log.Entry, string, bool) {
|
||||
if s.disableAuth {
|
||||
return connLog, "", true
|
||||
}
|
||||
if s.jwtConfig == nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
|
||||
connLog.Warn("auth rejected: no identity provider configured")
|
||||
return connLog, "", false
|
||||
}
|
||||
jwtUserID, err := s.authenticateJWT(header)
|
||||
userID, err := s.authenticateSession(header)
|
||||
if err != nil {
|
||||
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return connLog, "", false
|
||||
}
|
||||
return connLog.WithField("jwt_user", jwtUserID), jwtUserID, true
|
||||
return connLog.WithFields(log.Fields{
|
||||
"session_user": userID,
|
||||
"session_key": sessionKeyFingerprint(header.clientStatic),
|
||||
}), userID, true
|
||||
}
|
||||
|
||||
// sessionKeyFingerprint returns a short hex fingerprint of a client
|
||||
// static key for log correlation. Distinct VNC sessions of the same
|
||||
// user end up with distinct fingerprints because each session mints a
|
||||
// fresh keypair, so this lets an operator tell parallel sessions apart.
|
||||
func sessionKeyFingerprint(clientStatic []byte) string {
|
||||
if len(clientStatic) < 4 {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(clientStatic[:4])
|
||||
}
|
||||
|
||||
// acquireSessionResources returns the capturer/injector to use for this
|
||||
|
||||
@@ -47,10 +47,16 @@ func (s *Server) serviceAcceptLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.tryAcquireConnSlot() {
|
||||
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
enableTCPKeepAlive(conn, s.log)
|
||||
conn = newMetricsConn(conn, s.sessionRecorder)
|
||||
s.trackConn(conn)
|
||||
go func(c net.Conn) {
|
||||
defer s.releaseConnSlot()
|
||||
defer s.untrackConn(c)
|
||||
s.handleServiceConnectionDarwin(c, mgr)
|
||||
}(conn)
|
||||
@@ -69,7 +75,7 @@ func (s *Server) handleServiceConnectionDarwin(conn net.Conn, mgr *darwinAgentMa
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &darwinPrefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := readConnectionHeader(teeConn)
|
||||
header, err := s.readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
@@ -77,17 +83,13 @@ func (s *Server) handleServiceConnectionDarwin(conn net.Conn, mgr *darwinAgentMa
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if s.jwtConfig == nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
|
||||
connLog.Warn("auth rejected: no identity provider configured")
|
||||
return
|
||||
}
|
||||
if _, err := s.authenticateJWT(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
|
||||
if _, err := s.authenticateSession(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.registerConnAuth(conn, header)
|
||||
|
||||
token, err := mgr.ensure(s.ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -26,14 +25,11 @@ func (t *testCapturer) Capture() (*image.RGBA, error) {
|
||||
return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
|
||||
func startTestServer(t *testing.T, disableAuth bool) (net.Addr, *Server) {
|
||||
t.Helper()
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{})
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(disableAuth)
|
||||
if jwtConfig != nil {
|
||||
srv.SetJWTConfig(jwtConfig)
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
@@ -45,30 +41,28 @@ func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.
|
||||
return srv.listener.Addr(), srv
|
||||
}
|
||||
|
||||
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, false, nil)
|
||||
func TestAuthEnabled_NoSessionAuth_RejectsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, false)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header: attach mode, no username, no JWT.
|
||||
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
||||
// Header with no Noise handshake. Auth-required servers must reject
|
||||
// because no client static was authenticated.
|
||||
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should send RFB version then security failure.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
// Write client version to proceed through handshake.
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read security types: 0 means failure, followed by reason.
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
@@ -81,18 +75,17 @@ func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
|
||||
assert.Contains(t, string(reason), "identity proof missing", "rejection reason should mention missing identity proof")
|
||||
}
|
||||
|
||||
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, true, nil)
|
||||
addr, _ := startTestServer(t, true)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header: attach mode, no username, no JWT.
|
||||
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
||||
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
@@ -114,70 +107,12 @@ func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
||||
}
|
||||
|
||||
// TestAuthEnabled_InvalidJWT_RejectedBeforeRFB confirms the VNC server itself
|
||||
// (not just the JWT library) wires authentication into handleConnection. A
|
||||
// well-formed JWT-shaped token must hit the server's validation path and be
|
||||
// rejected with an AUTH_JWT_* reason, never reaching the RFB handshake.
|
||||
func TestAuthEnabled_InvalidJWT_RejectedBeforeRFB(t *testing.T) {
|
||||
addr, _ := startTestServer(t, false, &JWTConfig{
|
||||
Issuer: "https://example.invalid",
|
||||
KeysLocation: "https://example.invalid/.well-known/jwks.json",
|
||||
Audiences: []string{"test"},
|
||||
})
|
||||
|
||||
// Three-segment "JWT" with bogus base64. The server's authenticateJWT path
|
||||
// must catch this regardless of the IdP being unreachable.
|
||||
bogusJWT := "abc.def.ghi"
|
||||
header := make([]byte, 3+2+len(bogusJWT)+4+4)
|
||||
header[0] = ModeAttach
|
||||
binary.BigEndian.PutUint16(header[1:3], 0) // username len
|
||||
binary.BigEndian.PutUint16(header[3:5], uint16(len(bogusJWT)))
|
||||
copy(header[5:5+len(bogusJWT)], bogusJWT)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, byte(0), numTypes[0], "must fail security negotiation")
|
||||
|
||||
var reasonLen [4]byte
|
||||
_, err = io.ReadFull(conn, reasonLen[:])
|
||||
require.NoError(t, err)
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
// The reason must carry one of the server's AUTH_JWT_* codes, proving
|
||||
// the rejection came from authenticateJWT in handleConnection.
|
||||
r := string(reason)
|
||||
hasJWTReject := false
|
||||
for _, code := range []string{RejectCodeJWTInvalid, RejectCodeJWTExpired, RejectCodeAuthForbidden} {
|
||||
if strings.Contains(r, code) {
|
||||
hasJWTReject = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasJWTReject, "reason %q must include an AUTH_JWT_* code", r)
|
||||
}
|
||||
|
||||
// TestAuth_NoUnauthBytesPastHeader proves the server does not send any RFB
|
||||
// content to a connection that fails source validation. Specifically, the
|
||||
// server must close immediately and the client must see EOF before any RFB
|
||||
// version greeting is written.
|
||||
func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{})
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
// Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so
|
||||
@@ -198,37 +133,6 @@ func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
|
||||
require.Error(t, err, "non-overlay client must see EOF, not an RFB greeting")
|
||||
}
|
||||
|
||||
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
|
||||
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
|
||||
addr, _ := startTestServer(t, false, &JWTConfig{
|
||||
Issuer: "https://example.com",
|
||||
KeysLocation: "https://example.com/.well-known/jwks.json",
|
||||
Audiences: []string{"test"},
|
||||
})
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header with empty JWT.
|
||||
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
|
||||
}
|
||||
|
||||
func TestIsAllowedSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -289,7 +193,7 @@ func TestIsAllowedSource(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{})
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.localAddr = tc.localAddr
|
||||
srv.network = tc.network
|
||||
assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote))
|
||||
@@ -298,7 +202,7 @@ func TestIsAllowedSource(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStart_InvalidNetworkRejected(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{})
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
err := srv.Start(t.Context(), addr, netip.Prefix{})
|
||||
require.Error(t, err, "Start must refuse an invalid overlay prefix")
|
||||
@@ -306,7 +210,7 @@ func TestStart_InvalidNetworkRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{})
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken("deadbeefcafebabe")
|
||||
|
||||
@@ -334,7 +238,7 @@ func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{})
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
const tokenHex = "deadbeefcafebabe"
|
||||
srv.SetAgentToken(tokenHex)
|
||||
@@ -356,7 +260,7 @@ func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send session header so handleConnection can proceed past readConnectionHeader.
|
||||
header := make([]byte, 13) // ModeAttach + usernameLen=0 + jwtLen=0 + sessionID=0 + width=0 + height=0
|
||||
header := make([]byte, 11) // ModeAttach + usernameLen=0 + sessionID=0 + width=0 + height=0
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
@@ -371,7 +275,7 @@ func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
||||
func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
|
||||
// Default platformSessionManager() on non-Linux returns nil, so ModeSession
|
||||
// must be rejected with the UNSUPPORTED reason rather than crashing.
|
||||
srv := New(&testCapturer{}, &StubInputInjector{})
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
@@ -387,7 +291,7 @@ func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
// ModeSession with no username/JWT, so we exit on the vmgr==nil branch
|
||||
// ModeSession with no username, so we exit on the vmgr==nil branch
|
||||
// before username validation runs.
|
||||
header := []byte{ModeSession, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
|
||||
@@ -233,8 +233,9 @@ func (s *Server) platformInit() {
|
||||
startSASListener(s.ctx)
|
||||
}
|
||||
|
||||
// serviceAcceptLoop runs in Session 0. It validates source IP and
|
||||
// authenticates via JWT before proxying connections to the user-session agent.
|
||||
// serviceAcceptLoop runs in Session 0. It validates the source IP and
|
||||
// hands accepted connections to handleServiceConnection, which runs the
|
||||
// Noise_IK handshake before proxying to the user-session agent.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
|
||||
sm := newSessionManager(agentPort)
|
||||
@@ -255,18 +256,25 @@ func (s *Server) serviceAcceptLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.tryAcquireConnSlot() {
|
||||
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
enableTCPKeepAlive(conn, s.log)
|
||||
conn = newMetricsConn(conn, s.sessionRecorder)
|
||||
s.trackConn(conn)
|
||||
go func(c net.Conn) {
|
||||
defer s.releaseConnSlot()
|
||||
defer s.untrackConn(c)
|
||||
s.handleServiceConnection(c, sm)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceConnection validates the source IP and JWT, then proxies
|
||||
// the connection (with header bytes replayed) to the agent.
|
||||
// handleServiceConnection runs the connection-header handshake (including
|
||||
// Noise_IK), then proxies the connection (with header bytes replayed) to
|
||||
// the agent listening on loopback.
|
||||
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
@@ -279,7 +287,7 @@ func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &prefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := readConnectionHeader(teeConn)
|
||||
header, err := s.readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
@@ -287,17 +295,13 @@ func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if s.jwtConfig == nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
|
||||
connLog.Warn("auth rejected: no identity provider configured")
|
||||
return
|
||||
}
|
||||
if _, err := s.authenticateJWT(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
|
||||
if _, err := s.authenticateSession(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.registerConnAuth(conn, header)
|
||||
|
||||
// Replay buffered header bytes + remaining stream to the agent.
|
||||
replayConn := &prefixConn{
|
||||
|
||||
@@ -222,9 +222,9 @@ func (s *session) handshake() error {
|
||||
}
|
||||
|
||||
// sendSecurityTypes advertises only secNone. Authentication and access
|
||||
// control happen in the NetBird connection header (JWT, mode, username)
|
||||
// that precedes the RFB handshake, not via the protocol-level password
|
||||
// scheme.
|
||||
// control happen in the NetBird connection header (Noise_IK handshake,
|
||||
// mode, username) that precedes the RFB handshake; the protocol-level
|
||||
// password scheme is not supported.
|
||||
func (s *session) sendSecurityTypes() error {
|
||||
_, err := s.conn.Write([]byte{1, secNone})
|
||||
return err
|
||||
|
||||
@@ -225,6 +225,10 @@ func (s *session) handleResize() error {
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil
|
||||
}
|
||||
if w > maxFramebufferDim || h > maxFramebufferDim {
|
||||
s.log.Warnf("ignoring resize: %dx%d exceeds cap %d", w, h, maxFramebufferDim)
|
||||
return nil
|
||||
}
|
||||
if w == s.serverW && h == s.serverH {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
@@ -39,6 +40,7 @@ const (
|
||||
|
||||
func main() {
|
||||
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
|
||||
js.Global().Set("netbirdGenerateVNCSessionKey", createGenerateVNCSessionKeyMethod())
|
||||
|
||||
select {}
|
||||
}
|
||||
@@ -388,13 +390,31 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// createGenerateVNCSessionKeyMethod returns a JS func that mints a fresh
|
||||
// X25519 keypair, stashes the private half inside wasm under a random
|
||||
// session id, and returns { publicKey, sessionId } to JS. The private
|
||||
// key never leaves the wasm heap.
|
||||
func createGenerateVNCSessionKeyMethod() js.Func {
|
||||
return js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
id, pub, err := vnc.NewSessionKey()
|
||||
if err != nil {
|
||||
return js.ValueOf(err.Error())
|
||||
}
|
||||
out := js.Global().Get("Object").New()
|
||||
out.Set("sessionId", id)
|
||||
out.Set("publicKey", base64.StdEncoding.EncodeToString(pub))
|
||||
return out
|
||||
})
|
||||
}
|
||||
|
||||
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
|
||||
// JS signature: createVNCProxy(hostname, port, mode?, username?, jwt?, sessionID?, width?, height?)
|
||||
// mode: "attach" (default) or "session"
|
||||
// username: required when mode is "session"
|
||||
// jwt: authentication token (from OIDC session)
|
||||
// sessionID: Windows session ID (0 = console/auto)
|
||||
// width/height: requested viewport size for session mode (0 = server default)
|
||||
// JS signature: createVNCProxy(hostname, port, mode?, username?, keySessionID?, sessionID?, width?, height?, peerPublicKey?)
|
||||
// mode: "attach" (default) or "session"
|
||||
// username: required when mode is "session"
|
||||
// keySessionID: handle for the wasm-resident session keypair minted by netbirdGenerateVNCSessionKey
|
||||
// sessionID: Windows session ID (0 = console/auto)
|
||||
// width/height: requested viewport size for session mode (0 = server default)
|
||||
// peerPublicKey: base64 X25519 static pubkey of the destination peer (required for auth)
|
||||
func createVNCProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
params, err := parseVNCProxyArgs(args)
|
||||
@@ -408,14 +428,15 @@ func createVNCProxyMethod(client *netbird.Client) js.Func {
|
||||
}
|
||||
proxy := vnc.NewVNCProxy(client)
|
||||
return proxy.CreateProxy(vnc.ProxyRequest{
|
||||
Hostname: params.hostname,
|
||||
Port: params.port,
|
||||
Mode: params.mode,
|
||||
Username: params.username,
|
||||
JWT: params.jwt,
|
||||
SessionID: params.sessionID,
|
||||
Width: params.width,
|
||||
Height: params.height,
|
||||
Hostname: params.hostname,
|
||||
Port: params.port,
|
||||
Mode: params.mode,
|
||||
Username: params.username,
|
||||
SessionID: params.sessionID,
|
||||
Width: params.width,
|
||||
Height: params.height,
|
||||
PeerPublicKey: params.peerPublicKey,
|
||||
KeySessionID: params.keySessionID,
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -425,11 +446,12 @@ type vncProxyParams struct {
|
||||
port string
|
||||
mode string
|
||||
username string
|
||||
jwt string
|
||||
keySessionID string
|
||||
sessionID uint32
|
||||
width uint16
|
||||
height uint16
|
||||
rejectViaPromise bool // true when the JS caller expects a rejected Promise instead of a plain string return
|
||||
peerPublicKey string
|
||||
rejectViaPromise bool
|
||||
}
|
||||
|
||||
// parseVNCProxyArgs validates JS args for createVNCProxyMethod and returns
|
||||
@@ -480,7 +502,7 @@ func parseVNCProxyOptionalStrings(args []js.Value, p *vncProxyParams) error {
|
||||
p.username = args[3].String()
|
||||
}
|
||||
if len(args) > 4 && args[4].Type() == js.TypeString {
|
||||
p.jwt = args[4].String()
|
||||
p.keySessionID = args[4].String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -512,6 +534,9 @@ func parseVNCProxyOptionalNumbers(args []js.Value, p *vncProxyParams) error {
|
||||
}
|
||||
p.height = uint16(v)
|
||||
}
|
||||
if len(args) > 8 && args[8].Type() == js.TypeString {
|
||||
p.peerPublicKey = args[8].String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ package vnc
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -13,9 +15,65 @@ import (
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var cryptoRandRead = crand.Read
|
||||
|
||||
// vncIdentityMagic mirrors the server side in client/vnc/server/server.go.
|
||||
var vncIdentityMagic = []byte("NBV3")
|
||||
|
||||
// Noise_IK_25519_ChaChaPoly_SHA256 message sizes (with empty payloads).
|
||||
const (
|
||||
noiseInitiatorMsgLen = 96
|
||||
noiseResponderMsgLen = 48
|
||||
)
|
||||
|
||||
var vncNoiseSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
// sessionKeyStore retains per-session X25519 keypairs so the JS layer
|
||||
// only sees an opaque session id + the public key; the private key never
|
||||
// leaves wasm.
|
||||
var sessionKeyStore = struct {
|
||||
mu sync.Mutex
|
||||
keys map[string]noise.DHKey
|
||||
}{keys: map[string]noise.DHKey{}}
|
||||
|
||||
// NewSessionKey mints an X25519 keypair, stores the private half under a
|
||||
// fresh random session id, and returns (id, pubkey).
|
||||
func NewSessionKey() (string, []byte, error) {
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate keypair: %w", err)
|
||||
}
|
||||
idBytes := make([]byte, 16)
|
||||
if _, err := cryptoRandRead(idBytes); err != nil {
|
||||
return "", nil, fmt.Errorf("session id randomness: %w", err)
|
||||
}
|
||||
id := base64.RawURLEncoding.EncodeToString(idBytes)
|
||||
sessionKeyStore.mu.Lock()
|
||||
sessionKeyStore.keys[id] = kp
|
||||
sessionKeyStore.mu.Unlock()
|
||||
return id, kp.Public, nil
|
||||
}
|
||||
|
||||
// lookupSessionKey returns the keypair for id, or false if unknown.
|
||||
func lookupSessionKey(id string) (noise.DHKey, bool) {
|
||||
sessionKeyStore.mu.Lock()
|
||||
defer sessionKeyStore.mu.Unlock()
|
||||
kp, ok := sessionKeyStore.keys[id]
|
||||
return kp, ok
|
||||
}
|
||||
|
||||
// dropSessionKey removes the keypair for id. Called after the VNC
|
||||
// connection closes (or after a connect attempt fails terminally).
|
||||
func dropSessionKey(id string) {
|
||||
sessionKeyStore.mu.Lock()
|
||||
delete(sessionKeyStore.keys, id)
|
||||
sessionKeyStore.mu.Unlock()
|
||||
}
|
||||
|
||||
const (
|
||||
vncProxyHost = "vnc.proxy.local"
|
||||
vncProxyScheme = "ws"
|
||||
@@ -37,10 +95,12 @@ const (
|
||||
|
||||
// VNCProxy bridges WebSocket connections from noVNC in the browser
|
||||
// to TCP VNC server connections through the NetBird tunnel.
|
||||
type vncNBClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type VNCProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
nbClient vncNBClient
|
||||
activeConnections map[string]*vncConnection
|
||||
destinations map[string]vncDestination
|
||||
// pendingHandlers holds the js.Func for handleVNCWebSocket_<id> between
|
||||
@@ -52,13 +112,15 @@ type VNCProxy struct {
|
||||
}
|
||||
|
||||
type vncDestination struct {
|
||||
address string
|
||||
mode byte
|
||||
username string
|
||||
jwt string
|
||||
sessionID uint32 // Windows session ID (0 = auto/console)
|
||||
width uint16 // Requested viewport width for session mode (0 = default)
|
||||
height uint16 // Requested viewport height for session mode (0 = default)
|
||||
address string
|
||||
mode byte
|
||||
username string
|
||||
sessionPriv []byte
|
||||
sessionPub []byte
|
||||
sessionID uint32
|
||||
width uint16
|
||||
height uint16
|
||||
peerPubKey []byte
|
||||
}
|
||||
|
||||
type vncConnection struct {
|
||||
@@ -78,9 +140,7 @@ type vncConnection struct {
|
||||
}
|
||||
|
||||
// NewVNCProxy creates a new VNC proxy.
|
||||
func NewVNCProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *VNCProxy {
|
||||
func NewVNCProxy(client vncNBClient) *VNCProxy {
|
||||
return &VNCProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*vncConnection),
|
||||
@@ -94,10 +154,16 @@ type ProxyRequest struct {
|
||||
Port string
|
||||
Mode string
|
||||
Username string
|
||||
JWT string
|
||||
SessionID uint32
|
||||
Width uint16
|
||||
Height uint16
|
||||
// PeerPublicKey is the destination peer's base64 X25519 public key,
|
||||
// used as the responder static in the Noise_IK handshake.
|
||||
PeerPublicKey string
|
||||
// KeySessionID is the handle returned by generateVNCSessionKey. The
|
||||
// matching private key is looked up inside wasm and never crosses
|
||||
// the JS boundary.
|
||||
KeySessionID string
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given VNC destination.
|
||||
@@ -106,7 +172,7 @@ type ProxyRequest struct {
|
||||
// virtual display geometry for session mode; 0 means use the server default.
|
||||
// Returns a JS Promise that resolves to the WebSocket proxy URL.
|
||||
func (p *VNCProxy) CreateProxy(req ProxyRequest) js.Value {
|
||||
hostname, port, mode, username, jwt := req.Hostname, req.Port, req.Mode, req.Username, req.JWT
|
||||
hostname, port, mode, username := req.Hostname, req.Port, req.Mode, req.Username
|
||||
sessionID, width, height := req.SessionID, req.Width, req.Height
|
||||
address := net.JoinHostPort(hostname, port)
|
||||
|
||||
@@ -119,14 +185,51 @@ func (p *VNCProxy) CreateProxy(req ProxyRequest) js.Value {
|
||||
address: address,
|
||||
mode: m,
|
||||
username: username,
|
||||
jwt: jwt,
|
||||
sessionID: sessionID,
|
||||
width: width,
|
||||
height: height,
|
||||
}
|
||||
if req.KeySessionID != "" {
|
||||
kp, ok := lookupSessionKey(req.KeySessionID)
|
||||
if !ok {
|
||||
return rejectedPromise("unknown VNC session id")
|
||||
}
|
||||
// A session handle is single-use; drop it before the destination
|
||||
// holds the private bytes so a leaked handle can't be replayed.
|
||||
dropSessionKey(req.KeySessionID)
|
||||
dest.sessionPriv = kp.Private
|
||||
dest.sessionPub = kp.Public
|
||||
pub, err := decodePeerPubKey(req.PeerPublicKey)
|
||||
if err != nil {
|
||||
return rejectedPromise(fmt.Sprintf("invalid peer public key: %v", err))
|
||||
}
|
||||
dest.peerPubKey = pub
|
||||
}
|
||||
return p.newProxyPromise(address, mode, username, dest)
|
||||
}
|
||||
|
||||
// decodePeerPubKey parses a base64-encoded 32-byte X25519 public key.
|
||||
func decodePeerPubKey(b64 string) ([]byte, error) {
|
||||
if b64 == "" {
|
||||
return nil, errors.New("peer public key missing")
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
if len(raw) != 32 {
|
||||
return nil, fmt.Errorf("expected 32 bytes, got %d", len(raw))
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// rejectedPromise returns a resolved Promise carrying msg as an error
|
||||
// string, mirroring how CreateProxy reports earlier validation failures.
|
||||
func rejectedPromise(msg string) js.Value {
|
||||
promise := js.Global().Get("Promise")
|
||||
return promise.Call("resolve", js.ValueOf(msg))
|
||||
}
|
||||
|
||||
// newProxyPromise wraps the JS Promise creation + executor lifecycle so
|
||||
// CreateProxy stays a thin parameter-bundling entrypoint.
|
||||
func (p *VNCProxy) newProxyPromise(address, mode, username string, dest vncDestination) js.Value {
|
||||
@@ -288,46 +391,95 @@ func (p *VNCProxy) connectToVNC(conn *vncConnection) {
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
// sendSessionHeader writes mode, username, JWT, Windows session ID, and the
|
||||
// requested viewport size to the VNC server.
|
||||
// Format: [mode:1] [username_len:2] [username:N] [jwt_len:2] [jwt:N]
|
||||
//
|
||||
// [session_id:4] [width:2] [height:2]
|
||||
// sendSessionHeader writes the NetBird VNC connection header: mode +
|
||||
// username prefix, an optional Noise_IK handshake that authenticates the
|
||||
// client and the server, then the trailing sessionID / width / height
|
||||
// fields the daemon needs once auth is settled.
|
||||
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
|
||||
usernameBytes := []byte(dest.username)
|
||||
jwtBytes := []byte(dest.jwt)
|
||||
if len(usernameBytes) > 0xFFFF {
|
||||
return fmt.Errorf("username too long: %d bytes (max %d)", len(usernameBytes), 0xFFFF)
|
||||
}
|
||||
if len(jwtBytes) > 0xFFFF {
|
||||
return fmt.Errorf("jwt too long: %d bytes (max %d)", len(jwtBytes), 0xFFFF)
|
||||
prefix := make([]byte, 3+len(usernameBytes))
|
||||
prefix[0] = dest.mode
|
||||
prefix[1] = byte(len(usernameBytes) >> 8)
|
||||
prefix[2] = byte(len(usernameBytes))
|
||||
copy(prefix[3:], usernameBytes)
|
||||
if err := writeAll(conn, prefix); err != nil {
|
||||
return fmt.Errorf("write header prefix: %w", err)
|
||||
}
|
||||
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes)+4+4)
|
||||
hdr[0] = dest.mode
|
||||
hdr[1] = byte(len(usernameBytes) >> 8)
|
||||
hdr[2] = byte(len(usernameBytes))
|
||||
off := 3
|
||||
copy(hdr[off:], usernameBytes)
|
||||
off += len(usernameBytes)
|
||||
hdr[off] = byte(len(jwtBytes) >> 8)
|
||||
hdr[off+1] = byte(len(jwtBytes))
|
||||
off += 2
|
||||
copy(hdr[off:], jwtBytes)
|
||||
off += len(jwtBytes)
|
||||
hdr[off] = byte(dest.sessionID >> 24)
|
||||
hdr[off+1] = byte(dest.sessionID >> 16)
|
||||
hdr[off+2] = byte(dest.sessionID >> 8)
|
||||
hdr[off+3] = byte(dest.sessionID)
|
||||
off += 4
|
||||
hdr[off] = byte(dest.width >> 8)
|
||||
hdr[off+1] = byte(dest.width)
|
||||
hdr[off+2] = byte(dest.height >> 8)
|
||||
hdr[off+3] = byte(dest.height)
|
||||
|
||||
for off := 0; off < len(hdr); {
|
||||
n, err := conn.Write(hdr[off:])
|
||||
if dest.sessionPriv == nil {
|
||||
return p.writeHeaderTail(conn, dest)
|
||||
}
|
||||
if err := p.runNoiseHandshake(conn, dest); err != nil {
|
||||
return fmt.Errorf("noise handshake: %w", err)
|
||||
}
|
||||
return p.writeHeaderTail(conn, dest)
|
||||
}
|
||||
|
||||
// writeHeaderTail writes the post-auth trailing fields (sessionID,
|
||||
// width, height) the daemon reads regardless of whether the Noise
|
||||
// handshake was performed.
|
||||
func (p *VNCProxy) writeHeaderTail(conn net.Conn, dest vncDestination) error {
|
||||
tail := make([]byte, 4+4)
|
||||
tail[0] = byte(dest.sessionID >> 24)
|
||||
tail[1] = byte(dest.sessionID >> 16)
|
||||
tail[2] = byte(dest.sessionID >> 8)
|
||||
tail[3] = byte(dest.sessionID)
|
||||
tail[4] = byte(dest.width >> 8)
|
||||
tail[5] = byte(dest.width)
|
||||
tail[6] = byte(dest.height >> 8)
|
||||
tail[7] = byte(dest.height)
|
||||
if err := writeAll(conn, tail); err != nil {
|
||||
return fmt.Errorf("write header tail: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runNoiseHandshake performs the initiator side of a Noise_IK handshake
|
||||
// against the destination daemon. The session keypair authenticates the
|
||||
// client; the daemon's pre-known peer pubkey authenticates the server.
|
||||
func (p *VNCProxy) runNoiseHandshake(conn net.Conn, dest vncDestination) error {
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: noise.DHKey{Private: dest.sessionPriv, Public: dest.sessionPub},
|
||||
PeerStatic: dest.peerPubKey,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("noise initiator init: %w", err)
|
||||
}
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("noise write msg1: %w", err)
|
||||
}
|
||||
out := make([]byte, 0, len(vncIdentityMagic)+len(msg1))
|
||||
out = append(out, vncIdentityMagic...)
|
||||
out = append(out, msg1...)
|
||||
if err := writeAll(conn, out); err != nil {
|
||||
return fmt.Errorf("send noise msg1: %w", err)
|
||||
}
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return fmt.Errorf("set noise deadline: %w", err)
|
||||
}
|
||||
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||
msg2 := make([]byte, noiseResponderMsgLen)
|
||||
if _, err := io.ReadFull(conn, msg2); err != nil {
|
||||
return fmt.Errorf("read noise msg2: %w", err)
|
||||
}
|
||||
if _, _, _, err := state.ReadMessage(nil, msg2); err != nil {
|
||||
return fmt.Errorf("noise read msg2: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAll(conn net.Conn, buf []byte) error {
|
||||
for off := 0; off < len(buf); {
|
||||
n, err := conn.Write(buf[off:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("write session header: %w", err)
|
||||
return err
|
||||
}
|
||||
off += n
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -185,6 +185,7 @@ require (
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/flynn/noise v1.1.0 // indirect
|
||||
github.com/fredbi/uri v1.1.1 // indirect
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 // indirect
|
||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -162,6 +162,8 @@ github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
|
||||
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
|
||||
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
@@ -412,6 +414,7 @@ github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoK
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
@@ -756,6 +759,7 @@ goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
|
||||
@@ -2,6 +2,7 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -184,12 +185,47 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
|
||||
if networkMap.VNCAuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.VNCAuthorizedUsers)
|
||||
response.NetworkMap.VncAuth = &proto.VNCAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
response.NetworkMap.VncAuth = &proto.VNCAuth{
|
||||
AuthorizedUsers: hashedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: buildSessionPubKeysProto(ctx, networkMap.VNCSessionPubKeys),
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// buildSessionPubKeysProto decodes base64 X25519 session pubkeys and
|
||||
// hashes the user IDs they belong to, emitting the proto entries the
|
||||
// daemon's authorizer indexes by pubkey.
|
||||
func buildSessionPubKeysProto(ctx context.Context, in []types.VNCSessionPubKey) []*proto.SessionPubKey {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.SessionPubKey, 0, len(in))
|
||||
for _, e := range in {
|
||||
pub, err := base64.StdEncoding.DecodeString(e.PubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("decode VNC session pubkey: %v", err)
|
||||
continue
|
||||
}
|
||||
if len(pub) != 32 {
|
||||
log.WithContext(ctx).Warnf("VNC session pubkey wrong length: %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash, err := sshauth.HashUserID(e.UserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("hash VNC session user id: %v", err)
|
||||
continue
|
||||
}
|
||||
out = append(out, &proto.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIdHash: hash[:],
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
|
||||
@@ -517,6 +517,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
if protocol == types.PolicyRuleProtocolNetbirdSSH || protocol == types.PolicyRuleProtocolNetbirdVNC {
|
||||
policy.Rules[0].AuthorizedUser = userAuth.UserId
|
||||
}
|
||||
if protocol == types.PolicyRuleProtocolNetbirdVNC && req.SessionPubKey != nil {
|
||||
policy.Rules[0].SessionPubKey = *req.SessionPubKey
|
||||
}
|
||||
|
||||
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
||||
if err != nil {
|
||||
@@ -526,9 +529,10 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
resp := &api.PeerTemporaryAccessResponse{
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Rules: req.Rules,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Rules: req.Rules,
|
||||
TargetPubKey: targetPeer.Key,
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
|
||||
@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 8)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 1)
|
||||
assert.Contains(t, peers, account.Peers["peerI"])
|
||||
|
||||
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 7)
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
|
||||
peers, firewallRules, _, _, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
@@ -849,7 +849,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, map[string]map[string]struct{}, bool) {
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, map[string]map[string]struct{}, []VNCSessionPubKey, bool) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
ctxState := &peerConnResolveState{
|
||||
authorizedUsers: make(map[string]map[string]struct{}),
|
||||
@@ -869,7 +869,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
||||
}
|
||||
|
||||
peers, fwRules := getAccumulatedResources()
|
||||
return peers, fwRules, ctxState.authorizedUsers, ctxState.vncAuthorizedUsers, ctxState.sshEnabled
|
||||
return peers, fwRules, ctxState.authorizedUsers, ctxState.vncAuthorizedUsers, ctxState.vncSessionPubKeys, ctxState.sshEnabled
|
||||
}
|
||||
|
||||
func (a *Account) applyPolicyRule(
|
||||
|
||||
@@ -49,6 +49,7 @@ type NetworkMap struct {
|
||||
ForwardingRules []*ForwardingRule
|
||||
AuthorizedUsers map[string]map[string]struct{}
|
||||
VNCAuthorizedUsers map[string]map[string]struct{}
|
||||
VNCSessionPubKeys []VNCSessionPubKey
|
||||
EnableSSH bool
|
||||
}
|
||||
|
||||
|
||||
@@ -167,6 +167,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
|
||||
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
|
||||
AuthorizedUsers: connRes.authorizedUsers,
|
||||
VNCAuthorizedUsers: connRes.vncAuthorizedUsers,
|
||||
VNCSessionPubKeys: connRes.vncSessionPubKeys,
|
||||
EnableSSH: connRes.sshEnabled,
|
||||
}
|
||||
}
|
||||
@@ -177,6 +178,7 @@ type peerConnectionResult struct {
|
||||
firewallRules []*FirewallRule
|
||||
authorizedUsers map[string]map[string]struct{}
|
||||
vncAuthorizedUsers map[string]map[string]struct{}
|
||||
vncSessionPubKeys []VNCSessionPubKey
|
||||
sshEnabled bool
|
||||
}
|
||||
|
||||
@@ -210,6 +212,7 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) p
|
||||
firewallRules: fwRules,
|
||||
authorizedUsers: state.authorizedUsers,
|
||||
vncAuthorizedUsers: state.vncAuthorizedUsers,
|
||||
vncSessionPubKeys: state.vncSessionPubKeys,
|
||||
sshEnabled: state.sshEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,9 +15,21 @@ import (
|
||||
type peerConnResolveState struct {
|
||||
authorizedUsers map[string]map[string]struct{}
|
||||
vncAuthorizedUsers map[string]map[string]struct{}
|
||||
vncSessionPubKeys []VNCSessionPubKey
|
||||
sshEnabled bool
|
||||
}
|
||||
|
||||
// VNCSessionPubKey carries an ephemeral X25519 static public key the
|
||||
// dashboard registered via temporary-access. The daemon uses it as the
|
||||
// allowed-client side of a Noise_IK handshake; a successful handshake
|
||||
// authenticates the connection as UserID.
|
||||
type VNCSessionPubKey struct {
|
||||
// PubKey is the base64-encoded 32-byte X25519 public key.
|
||||
PubKey string
|
||||
// UserID is the unhashed user identity the pubkey authenticates as.
|
||||
UserID string
|
||||
}
|
||||
|
||||
// ruleAuthCallbacks lets Account and NetworkMapComponents share the per-rule
|
||||
// direction-and-auth logic while keeping their own context/state plumbing for
|
||||
// authorized-user collection and allowed-user lookups.
|
||||
@@ -57,6 +69,12 @@ func applyResolvedRuleToState(
|
||||
return
|
||||
}
|
||||
cb.collectVNCUsers(rule, state.vncAuthorizedUsers)
|
||||
if rule.SessionPubKey != "" && rule.AuthorizedUser != "" {
|
||||
state.vncSessionPubKeys = append(state.vncSessionPubKeys, VNCSessionPubKey{
|
||||
PubKey: rule.SessionPubKey,
|
||||
UserID: rule.AuthorizedUser,
|
||||
})
|
||||
}
|
||||
case policyRuleImpliesLegacySSH(rule) && targetPeerSSHEnabled:
|
||||
if !peerInDestinations {
|
||||
return
|
||||
|
||||
@@ -88,6 +88,12 @@ type PolicyRule struct {
|
||||
|
||||
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
|
||||
AuthorizedUser string
|
||||
|
||||
// SessionPubKey is the base64 Ed25519 public key the AuthorizedUser
|
||||
// will sign session-binding challenges with. Set together with
|
||||
// AuthorizedUser when the rule was created via temporary-access for
|
||||
// a VNC scope; empty otherwise.
|
||||
SessionPubKey string
|
||||
}
|
||||
|
||||
// Copy returns a copy of a policy rule
|
||||
@@ -109,6 +115,7 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
|
||||
AuthorizedUser: pm.AuthorizedUser,
|
||||
SessionPubKey: pm.SessionPubKey,
|
||||
}
|
||||
copy(rule.Destinations, pm.Destinations)
|
||||
copy(rule.Sources, pm.Sources)
|
||||
@@ -136,7 +143,8 @@ func (pm *PolicyRule) Equal(other *PolicyRule) bool {
|
||||
pm.Protocol != other.Protocol ||
|
||||
pm.SourceResource != other.SourceResource ||
|
||||
pm.DestinationResource != other.DestinationResource ||
|
||||
pm.AuthorizedUser != other.AuthorizedUser {
|
||||
pm.AuthorizedUser != other.AuthorizedUser ||
|
||||
pm.SessionPubKey != other.SessionPubKey {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// ErrTokenExpired signals that the iat-based token age check failed. Callers
|
||||
// use errors.Is to branch on it when they want to surface a stable machine-
|
||||
// readable reason (e.g. so a dashboard can prompt for re-login).
|
||||
var ErrTokenExpired = errors.New("token expired")
|
||||
|
||||
// CheckTokenAge validates that a JWT token's iat claim is within the given
|
||||
// maxAge duration. Returns an error if the claims are unparsable, the iat
|
||||
// claim is missing, or the token is too old.
|
||||
func CheckTokenAge(token *gojwt.Token, maxAge time.Duration) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token is nil")
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", UserIDFromToken(token))
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", UserIDFromToken(token))
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
if tokenAge > maxAge {
|
||||
return fmt.Errorf("%w for user=%s: age=%v, max=%v", ErrTokenExpired, userIDFromClaims(claims), tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserIDFromToken extracts a human-readable user identifier from a JWT token
|
||||
// for use in error messages. Returns "unknown" if the token or claims are nil.
|
||||
func UserIDFromToken(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return userIDFromClaims(claims)
|
||||
}
|
||||
|
||||
// userIDFromClaims extracts a user identifier from JWT claims, trying sub,
|
||||
// user_id, and email in order.
|
||||
func userIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
@@ -1007,6 +1007,10 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "tcp/80"
|
||||
session_pub_key:
|
||||
description: Ephemeral Ed25519 public key the requester will sign session-binding challenges with. Required for VNC rules; ignored for SSH and L4.
|
||||
type: string
|
||||
example: "n0r3pL4c3h0ld3rK3y=="
|
||||
required:
|
||||
- name
|
||||
- wg_pub_key
|
||||
@@ -1028,10 +1032,15 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "tcp/80"
|
||||
target_pub_key:
|
||||
description: Identity public key of the destination peer the temporary access was requested for. Used by the requester to verify the destination daemon's identity before transmitting credentials.
|
||||
type: string
|
||||
example: "n0r3pL4c3h0ld3rK3y=="
|
||||
required:
|
||||
- name
|
||||
- id
|
||||
- rules
|
||||
- target_pub_key
|
||||
AccessiblePeer:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PeerMinimum'
|
||||
|
||||
@@ -3391,6 +3391,9 @@ type PeerTemporaryAccessRequest struct {
|
||||
// Rules List of temporary access rules
|
||||
Rules []string `json:"rules"`
|
||||
|
||||
// SessionPubKey Ephemeral Ed25519 public key the requester will sign session-binding challenges with. Required for VNC rules; ignored for SSH and L4.
|
||||
SessionPubKey *string `json:"session_pub_key,omitempty"`
|
||||
|
||||
// WgPubKey Peer's WireGuard public key
|
||||
WgPubKey string `json:"wg_pub_key"`
|
||||
}
|
||||
@@ -3405,6 +3408,9 @@ type PeerTemporaryAccessResponse struct {
|
||||
|
||||
// Rules List of temporary access rules
|
||||
Rules []string `json:"rules"`
|
||||
|
||||
// TargetPubKey Identity public key of the destination peer the temporary access was requested for. Used by the requester to verify the destination daemon's identity before transmitting credentials.
|
||||
TargetPubKey string `json:"target_pub_key"`
|
||||
}
|
||||
|
||||
// PersonalAccessToken defines model for PersonalAccessToken.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -428,9 +428,6 @@ message MachineUserIndexes {
|
||||
|
||||
// VNCAuth represents VNC authorization configuration for a peer.
|
||||
message VNCAuth {
|
||||
// UserIDClaim is the JWT claim to be used to get the users ID
|
||||
string UserIDClaim = 1;
|
||||
|
||||
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via VNC
|
||||
repeated bytes AuthorizedUsers = 2;
|
||||
|
||||
@@ -438,6 +435,24 @@ message VNCAuth {
|
||||
// Used in session mode to determine which OS user to create the virtual session as.
|
||||
// The wildcard "*" allows any OS user.
|
||||
map<string, MachineUserIndexes> machine_users = 3;
|
||||
|
||||
// SessionPubKeys are short-lived X25519 static keypairs the dashboard
|
||||
// (or other temporary-access clients) registers per session. The
|
||||
// daemon runs a Noise_IK handshake against the matching pubkey to
|
||||
// authenticate the connection and resolve the pubkey back to a user.
|
||||
repeated SessionPubKey session_pub_keys = 4;
|
||||
}
|
||||
|
||||
// SessionPubKey binds an ephemeral X25519 static public key to a hashed
|
||||
// user identity so the daemon can authorize VNC connections that
|
||||
// complete a Noise_IK handshake with the matching private key.
|
||||
message SessionPubKey {
|
||||
// PubKey is the 32-byte X25519 static public key.
|
||||
bytes pub_key = 1;
|
||||
|
||||
// UserIDHash is the BLAKE2b-128 hash of the user ID this session
|
||||
// belongs to, matching the entries in VNCAuth.AuthorizedUsers.
|
||||
bytes user_id_hash = 2;
|
||||
}
|
||||
|
||||
// RemotePeerConfig represents a configuration of a remote peer.
|
||||
|
||||
Reference in New Issue
Block a user