mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-31 13:09:55 +00:00
Fold init-only VNC and SSH setters into Config-struct constructors
This commit is contained in:
@@ -54,9 +54,12 @@ var vncAgentCmd = &cobra.Command{
|
||||
// The per-user agent listens only on loopback and is gated by an
|
||||
// agent token shared with the daemon, so no X25519 identity key
|
||||
// is needed; auth is disabled at the RFB layer.
|
||||
srv := vncserver.New(capturer, injector, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken(token)
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
DisableAuth: true,
|
||||
AgentTokenHex: token,
|
||||
})
|
||||
|
||||
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), vncAgentPort)
|
||||
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
|
||||
|
||||
@@ -237,22 +237,18 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
NetworkValidation: wgAddr,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
|
||||
@@ -99,9 +99,9 @@ func (e *Engine) startVNCServer() error {
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
srv := vncserver.New(capturer, injector, e.config.WgPrivateKey[:])
|
||||
var sessionRecorder func(vncserver.SessionTick)
|
||||
if e.clientMetrics != nil {
|
||||
srv.SetSessionRecorder(func(t vncserver.SessionTick) {
|
||||
sessionRecorder = func(t vncserver.SessionTick) {
|
||||
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
|
||||
Period: t.Period,
|
||||
BytesOut: t.BytesOut,
|
||||
@@ -112,16 +112,20 @@ func (e *Engine) startVNCServer() error {
|
||||
MaxWriteBytes: t.MaxWriteBytes,
|
||||
WriteNanos: t.WriteNanos,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
if vncNeedsServiceMode() {
|
||||
serviceMode := vncNeedsServiceMode()
|
||||
if serviceMode {
|
||||
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
|
||||
srv.SetServiceMode(true)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
srv.SetNetstackNet(netstackNet)
|
||||
}
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
IdentityKey: e.config.WgPrivateKey[:],
|
||||
ServiceMode: serviceMode,
|
||||
SessionRecorder: sessionRecorder,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
})
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
|
||||
@@ -197,6 +197,14 @@ type Config struct {
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
|
||||
// NetstackNet, when non-nil, makes the SSH server listen via the
|
||||
// supplied userspace network stack instead of an OS socket.
|
||||
NetstackNet *netstack.Net
|
||||
|
||||
// NetworkValidation, when non-zero, restricts inbound connections to
|
||||
// peers inside the NetBird overlay defined by this WireGuard address.
|
||||
NetworkValidation wgaddr.Address
|
||||
}
|
||||
|
||||
// SessionInfo contains information about an active SSH session
|
||||
@@ -208,12 +216,15 @@ type SessionInfo struct {
|
||||
PortForwards []string
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
// New creates an SSH server instance from the supplied Config. Fields are
|
||||
// read once at construction; mutating Config afterwards has no effect.
|
||||
// JWT == nil disables JWT authentication.
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
netstackNet: config.NetstackNet,
|
||||
wgAddress: config.NetworkValidation,
|
||||
sessions: make(map[sessionKey]*sessionState),
|
||||
pendingAuthJWT: make(map[authKey]string),
|
||||
remoteForwardListeners: make(map[forwardKey]net.Listener),
|
||||
@@ -434,20 +445,6 @@ func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = net
|
||||
}
|
||||
|
||||
// SetNetworkValidation configures network-based connection filtering
|
||||
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// UpdateSSHAuth updates the SSH fine-grained access control configuration
|
||||
// This should be called when network map updates include new SSH auth configuration
|
||||
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
|
||||
|
||||
@@ -28,8 +28,11 @@ func noiseTestServer(t *testing.T) (net.Addr, *Server, []byte) {
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, kp.Private)
|
||||
srv.SetDisableAuth(false)
|
||||
srv := New(Config{
|
||||
Capturer: &testCapturer{},
|
||||
Injector: &StubInputInjector{},
|
||||
IdentityKey: kp.Private,
|
||||
})
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
@@ -338,8 +341,7 @@ func TestNoise_RevokedKey_RejectedAfterAuthUpdate(t *testing.T) {
|
||||
// without a static private key still rejects authenticated connections
|
||||
// fail-closed; it must not silently accept the client.
|
||||
func TestNoise_NoIdentityKey_FailsClosed(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(false)
|
||||
srv := New(Config{Capturer: &testCapturer{}, Injector: &StubInputInjector{}})
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
@@ -384,7 +386,7 @@ func TestNoise_DerivedIdentityPublicMatchesPrivate(t *testing.T) {
|
||||
for i := range priv {
|
||||
priv[i] = byte(i + 1)
|
||||
}
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, priv)
|
||||
srv := New(Config{Capturer: &testCapturer{}, Injector: &StubInputInjector{}, IdentityKey: priv})
|
||||
|
||||
expected, err := curve25519.X25519(priv, curve25519.Basepoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -254,30 +254,57 @@ type virtualSessionManager interface {
|
||||
StopAll()
|
||||
}
|
||||
|
||||
// New creates a VNC server. identityKey is the 32-byte X25519 private
|
||||
// key used by the daemon in the Noise_IK handshake; nil disables auth.
|
||||
// The protocol-level VNC password scheme is not supported.
|
||||
func New(capturer ScreenCapturer, injector InputInjector, identityKey []byte) *Server {
|
||||
// Config bundles the values the VNC server needs at construction time.
|
||||
// Fields are read once by New; mutating them afterwards has no effect.
|
||||
// Optional fields are nil/zero when unused. The hex-encoded AgentTokenHex
|
||||
// is decoded internally and an invalid value is logged and treated as
|
||||
// empty, matching the legacy SetAgentToken behavior.
|
||||
type Config struct {
|
||||
Capturer ScreenCapturer
|
||||
Injector InputInjector
|
||||
IdentityKey []byte
|
||||
ServiceMode bool
|
||||
SessionRecorder func(SessionTick)
|
||||
DisableAuth bool
|
||||
AgentTokenHex string
|
||||
NetstackNet *netstack.Net
|
||||
}
|
||||
|
||||
// New creates a VNC server from the provided Config. IdentityKey is the
|
||||
// 32-byte X25519 private key used in the Noise_IK handshake; nil disables
|
||||
// auth. The protocol-level VNC password scheme is not supported.
|
||||
func New(cfg Config) *Server {
|
||||
s := &Server{
|
||||
capturer: capturer,
|
||||
injector: injector,
|
||||
identityKey: identityKey,
|
||||
authorizer: sshauth.NewAuthorizer(),
|
||||
log: log.WithField("component", "vnc-server"),
|
||||
sessions: make(map[uint64]ActiveSessionInfo),
|
||||
sessionConns: make(map[uint64]net.Conn),
|
||||
acceptedConns: make(map[net.Conn]struct{}),
|
||||
connAuth: make(map[net.Conn]connAuthInfo),
|
||||
connSem: make(chan struct{}, maxConcurrentVNCConns),
|
||||
capturer: cfg.Capturer,
|
||||
injector: cfg.Injector,
|
||||
identityKey: cfg.IdentityKey,
|
||||
serviceMode: cfg.ServiceMode,
|
||||
sessionRecorder: cfg.SessionRecorder,
|
||||
disableAuth: cfg.DisableAuth,
|
||||
netstackNet: cfg.NetstackNet,
|
||||
authorizer: sshauth.NewAuthorizer(),
|
||||
log: log.WithField("component", "vnc-server"),
|
||||
sessions: make(map[uint64]ActiveSessionInfo),
|
||||
sessionConns: make(map[uint64]net.Conn),
|
||||
acceptedConns: make(map[net.Conn]struct{}),
|
||||
connAuth: make(map[net.Conn]connAuthInfo),
|
||||
connSem: make(chan struct{}, maxConcurrentVNCConns),
|
||||
}
|
||||
if len(identityKey) == 32 {
|
||||
pub, err := curve25519.X25519(identityKey, curve25519.Basepoint)
|
||||
if len(cfg.IdentityKey) == 32 {
|
||||
pub, err := curve25519.X25519(cfg.IdentityKey, curve25519.Basepoint)
|
||||
if err == nil {
|
||||
s.identityPublic = pub
|
||||
} else {
|
||||
s.log.Warnf("derive identity public key: %v", err)
|
||||
}
|
||||
}
|
||||
if cfg.AgentTokenHex != "" {
|
||||
if b, err := hex.DecodeString(cfg.AgentTokenHex); err == nil {
|
||||
s.agentToken = b
|
||||
} else {
|
||||
s.log.Warnf("invalid agent token: %v", err)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -407,47 +434,6 @@ func (s *Server) revokeUnauthorizedSessions() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
|
||||
func (s *Server) SetServiceMode(enabled bool) {
|
||||
s.serviceMode = enabled
|
||||
}
|
||||
|
||||
// SetSessionRecorder installs a callback that receives a SessionTick
|
||||
// each sessionTickInterval during a VNC session and one final tick on
|
||||
// session close. Pass nil to disable. Empty ticks (no wire activity)
|
||||
// are skipped.
|
||||
func (s *Server) SetSessionRecorder(recorder func(SessionTick)) {
|
||||
s.sessionRecorder = recorder
|
||||
}
|
||||
|
||||
// SetDisableAuth disables authentication entirely.
|
||||
func (s *Server) SetDisableAuth(disable bool) {
|
||||
s.disableAuth = disable
|
||||
}
|
||||
|
||||
// SetAgentToken sets a hex-encoded token that must be presented by incoming
|
||||
// connections before any VNC data. Used in agent mode to verify that only the
|
||||
// trusted service process connects.
|
||||
func (s *Server) SetAgentToken(hexToken string) {
|
||||
if hexToken == "" {
|
||||
return
|
||||
}
|
||||
b, err := hex.DecodeString(hexToken)
|
||||
if err != nil {
|
||||
s.log.Warnf("invalid agent token: %v", err)
|
||||
return
|
||||
}
|
||||
s.agentToken = b
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace-only listening.
|
||||
// When set, the VNC server listens via netstack instead of a real OS socket.
|
||||
func (s *Server) SetNetstackNet(n *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = n
|
||||
}
|
||||
|
||||
// UpdateVNCAuth updates the fine-grained authorization configuration and
|
||||
// closes any live session whose identity no longer authenticates under
|
||||
// the new policy. Revocation is event-driven: there is no periodic
|
||||
|
||||
@@ -28,8 +28,11 @@ func (t *testCapturer) Capture() (*image.RGBA, error) {
|
||||
func startTestServer(t *testing.T, disableAuth bool) (net.Addr, *Server) {
|
||||
t.Helper()
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(disableAuth)
|
||||
srv := New(Config{
|
||||
Capturer: &testCapturer{},
|
||||
Injector: &StubInputInjector{},
|
||||
DisableAuth: disableAuth,
|
||||
})
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
@@ -112,8 +115,11 @@ func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||
// server must close immediately and the client must see EOF before any RFB
|
||||
// version greeting is written.
|
||||
func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv := New(Config{
|
||||
Capturer: &testCapturer{},
|
||||
Injector: &StubInputInjector{},
|
||||
DisableAuth: true,
|
||||
})
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
// Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so
|
||||
// the loopback short-circuit in isAllowedSource doesn't apply.
|
||||
@@ -193,7 +199,7 @@ func TestIsAllowedSource(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv := New(Config{Capturer: &testCapturer{}, Injector: &StubInputInjector{}})
|
||||
srv.localAddr = tc.localAddr
|
||||
srv.network = tc.network
|
||||
assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote))
|
||||
@@ -202,7 +208,7 @@ func TestIsAllowedSource(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStart_InvalidNetworkRejected(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv := New(Config{Capturer: &testCapturer{}, Injector: &StubInputInjector{}})
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
err := srv.Start(t.Context(), addr, netip.Prefix{})
|
||||
require.Error(t, err, "Start must refuse an invalid overlay prefix")
|
||||
@@ -210,9 +216,12 @@ func TestStart_InvalidNetworkRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken("deadbeefcafebabe")
|
||||
srv := New(Config{
|
||||
Capturer: &testCapturer{},
|
||||
Injector: &StubInputInjector{},
|
||||
DisableAuth: true,
|
||||
AgentTokenHex: "deadbeefcafebabe",
|
||||
})
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
@@ -238,10 +247,13 @@ func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
const tokenHex = "deadbeefcafebabe"
|
||||
srv.SetAgentToken(tokenHex)
|
||||
srv := New(Config{
|
||||
Capturer: &testCapturer{},
|
||||
Injector: &StubInputInjector{},
|
||||
DisableAuth: true,
|
||||
AgentTokenHex: tokenHex,
|
||||
})
|
||||
token, err := hex.DecodeString(tokenHex)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -275,8 +287,11 @@ func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
||||
func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
|
||||
// Default platformSessionManager() on non-Linux returns nil, so ModeSession
|
||||
// must be rejected with the UNSUPPORTED reason rather than crashing.
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv := New(Config{
|
||||
Capturer: &testCapturer{},
|
||||
Injector: &StubInputInjector{},
|
||||
DisableAuth: true,
|
||||
})
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
|
||||
Reference in New Issue
Block a user