From 4e03f708a4cf9461f149b0227cbefae9abbad29e Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 9 Oct 2025 17:39:02 +0300 Subject: [PATCH 001/118] fix dns forwarder port update (#4613) fix dns forwarder port update (#4613) --- client/internal/dnsfwd/manager.go | 16 +++++++--------- client/internal/engine.go | 8 ++++++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 5c7a3fbdd..a3a4ba40f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -40,7 +40,6 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder - port uint16 } func ListenPort() uint16 { @@ -49,11 +48,16 @@ func ListenPort() uint16 { return listenPort } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { +func SetListenPort(port uint16) { + listenPortMu.Lock() + listenPort = port + listenPortMu.Unlock() +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, - port: port, } } @@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.port > 0 { - listenPortMu.Lock() - listenPort = m.port - listenPortMu.Unlock() - } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 646e059d4..bebf04f6c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder( return } + if forwarderPort > 0 { + dnsfwd.SetListenPort(forwarderPort) + } + if !enabled { if e.dnsForwardMgr == nil { return @@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder( if len(fwdEntries) > 0 { switch { case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil @@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { log.Errorf("failed to stop DNS forward: %v", err) } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil From d35a845dbd336206f2b306384f1faca15cd4bd03 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 9 Oct 2025 22:18:00 +0300 Subject: [PATCH 002/118] [management] sync all other peers on peer add/remove (#4614) --- management/server/peer.go | 27 ++------------------------- management/server/peer_test.go | 4 ++-- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 4cf5d1e46..218edf299 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } var peer *nbpeer.Peer - var updateAccountPeers bool var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) - if err != nil { - return err - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers && userID != activity.SystemInitiator { + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) - if err != nil { - updateAccountPeers = true - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } @@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if updateAccountPeers { - am.BufferUpdateAccountPeers(ctx, accountID) - } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } @@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } -// IsPeerInActiveGroup checks if the given peer is part of a group that is used -// in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) - if err != nil { - return false, err - } - return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction -} - // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 42b3244ae..fd795b926 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) // + peerShouldReceiveUpdate(t, updMsg) // close(done) }() @@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() From bedd3cabc994d6bb9903a95fc6753331c987c25c Mon Sep 17 00:00:00 2001 From: Kostya Leschenko Date: Fri, 10 Oct 2025 16:24:24 +0300 Subject: [PATCH 003/118] [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) --- client/internal/dns/systemd_linux.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 0e8a53a63..d9854c033 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -31,6 +31,7 @@ const ( systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" + systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS" systemdDbusResolvConfModeForeign = "foreign" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" @@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana log.Warnf("failed to set DNSSEC to 'no': %v", err) } + // We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off + if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil { + log.Warnf("failed to set DNSOverTLS to 'no': %v", err) + } + var ( searchDomains []string matchDomains []string From 5151f19d29514436f65fe40faebf4c89da06133d Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 10 Oct 2025 16:15:51 +0200 Subject: [PATCH 004/118] [management] pass temporary flag to validator (#4599) --- go.mod | 2 +- go.sum | 4 ++-- management/server/integrated_validator.go | 2 +- .../server/integrations/integrated_validator/interface.go | 4 ++-- management/server/peer.go | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index a1560b409..31b45e881 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 13838b82d..6b0b298a7 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 21f11bfce..251c04273 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } -func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { +func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index ce632d567..be05c2527 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -3,16 +3,16 @@ package integrated_validator import ( "context" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) - PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error diff --git a/management/server/peer.go b/management/server/peer.go index 218edf299..30b7073ef 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -578,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { From 0d2e67983ad2de4fc301514deff3409879819520 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 10 Oct 2025 14:16:48 -0300 Subject: [PATCH 005/118] [misc] Add service definition for netbird-signal (#4620) --- infrastructure_files/docker-compose.yml.tmpl.traefik | 1 + 1 file changed, 1 insertion(+) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 196e26a66..0010974c5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -49,6 +49,7 @@ services: - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.routers.netbird-signal.service=netbird-signal - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c From 000e99e7f3f5c18741ee3412ba113c32c6ce3aa7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 13 Oct 2025 17:50:16 +0200 Subject: [PATCH 006/118] [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) --- client/wasm/internal/rdp/cert_validation.go | 15 ++- client/wasm/internal/rdp/rdcleanpath.go | 93 ++++++++++++-- .../wasm/internal/rdp/rdcleanpath_handlers.go | 119 +++++++++--------- 3 files changed, 152 insertions(+), 75 deletions(-) diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 4a23a4bc8..1678c3996 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert } } -func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { - return &tls.Config{ +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config { + config := &tls.Config{ InsecureSkipVerify: true, // We'll validate manually after handshake VerifyConnection: func(cs tls.ConnectionState) error { var certChain [][]byte @@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl return nil }, } + + // CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3 + if requiresCredSSP { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS12 + } else { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS13 + } + + return config } diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 8062a05cc..16bf63bb9 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -6,11 +6,13 @@ import ( "context" "crypto/tls" "encoding/asn1" + "errors" "fmt" "io" "net" "sync" "syscall/js" + "time" log "github.com/sirupsen/logrus" ) @@ -19,18 +21,34 @@ const ( RDCleanPathVersion = 3390 RDCleanPathProxyHost = "rdcleanpath.proxy.local" RDCleanPathProxyScheme = "ws" + + rdpDialTimeout = 15 * time.Second + + GeneralErrorCode = 1 + WSAETimedOut = 10060 + WSAEConnRefused = 10061 + WSAEConnAborted = 10053 + WSAEConnReset = 10054 + WSAEGenericError = 10050 ) type RDCleanPathPDU struct { - Version int64 `asn1:"tag:0,explicit"` - Error []byte `asn1:"tag:1,explicit,optional"` - Destination string `asn1:"utf8,tag:2,explicit,optional"` - ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` - ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` - PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` - X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` - ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` - ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` + Version int64 `asn1:"tag:0,explicit"` + Error RDCleanPathErr `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathErr struct { + ErrorCode int16 `asn1:"tag:0,explicit"` + HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"` + WSALastError int16 `asn1:"tag:2,explicit,optional"` + TLSAlertCode int8 `asn1:"tag:3,explicit,optional"` } type RDCleanPathProxy struct { @@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] destination := conn.destination log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } conn.rdpConn = rdpConn @@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] _, err = rdpConn.Write(firstPacket) if err != nil { log.Errorf("Failed to write first packet: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] n, err := rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { conn.wsHandlers.Call("send", uint8Array.Get("buffer")) } } + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + p.sendToWebSocket(conn, data) +} + +func errorToWSACode(err error) int16 { + if err == nil { + return WSAEGenericError + } + var netErr *net.OpError + if errors.As(err, &netErr) && netErr.Timeout() { + return WSAETimedOut + } + if errors.Is(err, context.DeadlineExceeded) { + return WSAETimedOut + } + if errors.Is(err, context.Canceled) { + return WSAEConnAborted + } + if errors.Is(err, io.EOF) { + return WSAEConnReset + } + return WSAEGenericError +} + +func newWSAError(err error) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + WSALastError: errorToWSACode(err), + }, + } +} + +func newHTTPError(statusCode int16) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + HTTPStatusCode: statusCode, + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go index 010efa5ea..97bb46338 100644 --- a/client/wasm/internal/rdp/rdcleanpath_handlers.go +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -3,6 +3,7 @@ package rdp import ( + "context" "crypto/tls" "encoding/asn1" "io" @@ -11,11 +12,17 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP) + protocolSSL = 0x00000001 + protocolHybridEx = 0x00000008 +) + func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) if pdu.Version != RDCleanPathVersion { - p.sendRDCleanPathError(conn, "Unsupported version") + p.sendRDCleanPathError(conn, newHTTPError(400)) return } @@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl destination = pdu.Destination } - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) - p.sendRDCleanPathError(conn, "Connection failed") + p.sendRDCleanPathError(conn, newWSAError(err)) p.cleanupConnection(conn) return } @@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl p.setupTLSConnection(conn, pdu) } +// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required. +// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags. +// Returns (requiresTLS12, selectedProtocol, detectionSuccessful). +func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) { + const minResponseLength = 19 + + if len(x224Response) < minResponseLength { + return false, 0, false + } + + // Per X.224 specification: + // x224Response[0] == 0x03: Length of X.224 header (3 bytes) + // x224Response[5] == 0xD0: X.224 Data TPDU code + if x224Response[0] != 0x03 || x224Response[5] != 0xD0 { + return false, 0, false + } + + if x224Response[11] == 0x02 { + flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 | + uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24 + + hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0 + return hasNLA, flags, true + } + + return false, 0, false +} + func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { var x224Response []byte if len(pdu.X224ConnectionPDU) > 0 { @@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) if err != nil { log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean n, err := conn.rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") + p.sendRDCleanPathError(conn, newWSAError(err)) return } x224Response = response[:n] log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) } - tlsConfig := p.getTLSConfigWithValidation(conn) + requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response) + if detected { + if requiresCredSSP { + log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol) + } else { + log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol) + } + } else { + log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3") + } + + tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP) tlsConn := tls.Client(conn.rdpConn, tlsConfig) conn.tlsConn = tlsConn if err := tlsConn.Handshake(); err != nil { log.Errorf("TLS handshake failed: %v", err) - p.sendRDCleanPathError(conn, "TLS handshake failed") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean p.cleanupConnection(conn) } -func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { - if len(pdu.X224ConnectionPDU) > 0 { - log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) - _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) - if err != nil { - log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") - return - } - - response := make([]byte, 1024) - n, err := conn.rdpConn.Read(response) - if err != nil { - log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") - return - } - - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - X224ConnectionPDU: response[:n], - ServerAddr: conn.destination, - } - - p.sendRDCleanPathPDU(conn, responsePDU) - } else { - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - ServerAddr: conn.destination, - } - p.sendRDCleanPathPDU(conn, responsePDU) - } - - go p.forwardConnToWS(conn, conn.rdpConn, "TCP") - go p.forwardWSToConn(conn, conn.rdpConn, "TCP") - - <-conn.ctx.Done() - log.Debug("TCP connection context done, cleaning up") - p.cleanupConnection(conn) -} - func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { data, err := asn1.Marshal(pdu) if err != nil { @@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean p.sendToWebSocket(conn, data) } -func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { - pdu := RDCleanPathPDU{ - Version: RDCleanPathVersion, - Error: []byte(errorMsg), - } - - data, err := asn1.Marshal(pdu) - if err != nil { - log.Errorf("Failed to marshal error PDU: %v", err) - return - } - - p.sendToWebSocket(conn, data) -} - func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { msgChan := make(chan []byte) errChan := make(chan error) From bb37dc89cef3d8338277503ff5e0bbfecb76ffca Mon Sep 17 00:00:00 2001 From: John Conley <8932043+jfrconley@users.noreply.github.com> Date: Thu, 16 Oct 2025 01:46:29 -0700 Subject: [PATCH 007/118] [management] feat: Basic PocketID IDP integration (#4529) --- management/server/idp/auth0_test.go | 8 +- management/server/idp/idp.go | 6 + management/server/idp/pocketid.go | 384 +++++++++++++++++++++++++ management/server/idp/pocketid_test.go | 138 +++++++++ 4 files changed, 533 insertions(+), 3 deletions(-) create mode 100644 management/server/idp/pocketid.go create mode 100644 management/server/idp/pocketid_test.go diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 66c16870b..bc352f117 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -26,9 +26,11 @@ type mockHTTPClient struct { } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - body, err := io.ReadAll(req.Body) - if err == nil { - c.reqBody = string(body) + if req.Body != nil { + body, err := io.ReadAll(req.Body) + if err == nil { + c.reqBody = string(body) + } } return &http.Response{ StatusCode: c.code, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 51f99b3b7..f06e57196 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr APIToken: config.ExtraConfig["ApiToken"], } return NewJumpCloudManager(jumpcloudConfig, appMetrics) + case "pocketid": + pocketidConfig := PocketIdClientConfig{ + APIToken: config.ExtraConfig["ApiToken"], + ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"], + } + return NewPocketIdManager(pocketidConfig, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go new file mode 100644 index 000000000..38a5cc67f --- /dev/null +++ b/management/server/idp/pocketid.go @@ -0,0 +1,384 @@ +package idp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type PocketIdManager struct { + managementEndpoint string + apiToken string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +type pocketIdCustomClaimDto struct { + Key string `json:"key"` + Value string `json:"value"` +} + +type pocketIdUserDto struct { + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + Disabled bool `json:"disabled"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + ID string `json:"id"` + IsAdmin bool `json:"isAdmin"` + LastName string `json:"lastName"` + LdapID string `json:"ldapId"` + Locale string `json:"locale"` + UserGroups []pocketIdUserGroupDto `json:"userGroups"` + Username string `json:"username"` +} + +type pocketIdUserCreateDto struct { + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + IsAdmin bool `json:"isAdmin,omitempty"` + LastName string `json:"lastName,omitempty"` + Locale string `json:"locale,omitempty"` + Username string `json:"username"` +} + +type pocketIdPaginatedUserDto struct { + Data []pocketIdUserDto `json:"data"` + Pagination pocketIdPaginationDto `json:"pagination"` +} + +type pocketIdPaginationDto struct { + CurrentPage int `json:"currentPage"` + ItemsPerPage int `json:"itemsPerPage"` + TotalItems int `json:"totalItems"` + TotalPages int `json:"totalPages"` +} + +func (p *pocketIdUserDto) userData() *UserData { + return &UserData{ + Email: p.Email, + Name: p.DisplayName, + ID: p.ID, + AppMetadata: AppMetadata{}, + } +} + +type pocketIdUserGroupDto struct { + CreatedAt string `json:"createdAt"` + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + FriendlyName string `json:"friendlyName"` + ID string `json:"id"` + LdapID string `json:"ldapId"` + Name string `json:"name"` +} + +func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + if config.ManagementEndpoint == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing") + } + + if config.APIToken == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing") + } + + credentials := &PocketIdCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &PocketIdManager{ + managementEndpoint: config.ManagementEndpoint, + apiToken: config.APIToken, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) { + var MethodsWithBody = []string{http.MethodPost, http.MethodPut} + if !slices.Contains(MethodsWithBody, method) && body != "" { + return nil, fmt.Errorf("Body provided to unsupported method: %s", method) + } + + reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource) + if query != nil { + reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode()) + } + var req *http.Request + var err error + if body != "" { + req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body)) + } else { + req, err = http.NewRequestWithContext(ctx, method, reqURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Add("X-API-KEY", p.apiToken) + + if body != "" { + req.Header.Add("content-type", "application/json") + req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength)) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// getAllUsersPaginated fetches all users from PocketID API using pagination +func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) { + var allUsers []pocketIdUserDto + currentPage := 1 + + for { + params := url.Values{} + // Copy existing search parameters + for key, values := range searchParams { + params[key] = values + } + + params.Set("pagination[limit]", "100") + params.Set("pagination[page]", fmt.Sprintf("%d", currentPage)) + + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + var profiles pocketIdPaginatedUserDto + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + allUsers = append(allUsers, profiles.Data...) + + // Check if we've reached the last page + if currentPage >= profiles.Pagination.TotalPages { + break + } + + currentPage++ + } + + return allUsers, nil +} + +func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { + body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var user pocketIdUserDto + err = p.helper.Unmarshal(body, &user) + if err != nil { + return nil, err + } + + userData := user.userData() + userData.AppMetadata = appMetadata + + return userData, nil +} + +func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAccount() + } + + users := make([]*UserData, 0) + for _, profile := range allUsers { + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountId + + users = append(users, userData) + } + return users, nil +} + +func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range allUsers { + userData := profile.userData() + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + } + + return indexedUsers, nil +} + +func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + firstLast := strings.Split(name, " ") + + createUser := pocketIdUserCreateDto{ + Disabled: false, + DisplayName: name, + Email: email, + FirstName: firstLast[0], + LastName: firstLast[1], + Username: firstLast[0] + "." + firstLast[1], + } + payload, err := p.helper.Marshal(createUser) + if err != nil { + return nil, err + } + + body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload)) + if err != nil { + return nil, err + } + var newUser pocketIdUserDto + err = p.helper.Unmarshal(body, &newUser) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountCreateUser() + } + var pending bool = true + ret := &UserData{ + Email: email, + Name: name, + ID: newUser.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pending, + WTInvitedBy: invitedByEmail, + }, + } + return ret, nil +} + +func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + params := url.Values{ + // This value a + "search": []string{email}, + } + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + var profiles struct{ data []pocketIdUserDto } + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles.data { + users = append(users, profile.userData()) + } + return users, nil +} + +func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "") + if err != nil { + return err + } + return nil +} + +func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "") + if err != nil { + return err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + +var _ Manager = (*PocketIdManager)(nil) + +type PocketIdClientConfig struct { + APIToken string + ManagementEndpoint string +} + +type PocketIdCredentials struct { + clientConfig PocketIdClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + appMetrics telemetry.AppMetrics +} + +var _ ManagerCredentials = (*PocketIdCredentials)(nil) + +func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) { + return JWTToken{}, nil +} diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go new file mode 100644 index 000000000..49075a0d3 --- /dev/null +++ b/management/server/idp/pocketid_test.go @@ -0,0 +1,138 @@ +package idp + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + + +func TestNewPocketIdManager(t *testing.T) { + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } + + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } +} + +func TestPocketID_GetUserDataByID(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) +} + +func TestPocketID_GetAccount_WithPagination(t *testing.T) { + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) +} + +func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) +} + +func TestPocketID_CreateUser(t *testing.T) { + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) +} + +func TestPocketID_InviteAndDeleteUser(t *testing.T) { + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) + + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) +} From 277aa2b7cc82368d6ec0c6f73acfdfa676ad2c9b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:13:41 +0200 Subject: [PATCH 008/118] [client] Fix missing flag values in profiles (#4650) --- client/server/server.go | 7 + client/server/setconfig_test.go | 298 ++++++++++++++++++++++++++++++++ 2 files changed, 305 insertions(+) create mode 100644 client/server/setconfig_test.go diff --git a/client/server/server.go b/client/server/server.go index e6de608c5..89f50a1ef 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.CustomDNSAddress = []byte{} } + config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + + if msg.DnsRouteInterval != nil { + interval := msg.DnsRouteInterval.AsDuration() + config.DNSRouteInterval = &interval + } + config.RosenpassEnabled = msg.RosenpassEnabled config.RosenpassPermissive = msg.RosenpassPermissive config.DisableAutoConnect = msg.DisableAutoConnect diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go new file mode 100644 index 000000000..1260bcc78 --- /dev/null +++ b/client/server/setconfig_test.go @@ -0,0 +1,298 @@ +package server + +import ( + "context" + "os/user" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" +) + +// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config. +// This test uses reflection to detect when new fields are added but not handled in SetConfig. +func TestSetConfig_AllFieldsSaved(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.ConfigDirOverride = tempDir + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + + currUser, err := user.Current() + require.NoError(t, err) + + profName := "test-profile" + + ic := profilemanager.ConfigInput{ + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: "https://api.netbird.io:443", + } + _, err = profilemanager.UpdateOrCreateConfig(ic) + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + }) + require.NoError(t, err) + + ctx := context.Background() + s := New(ctx, "console", "", false, false) + + rosenpassEnabled := true + rosenpassPermissive := true + serverSSHAllowed := true + interfaceName := "utun100" + wireguardPort := int64(51820) + preSharedKey := "test-psk" + disableAutoConnect := true + networkMonitor := true + disableClientRoutes := true + disableServerRoutes := true + disableDNS := true + disableFirewall := true + blockLANAccess := true + disableNotifications := true + lazyConnectionEnabled := true + blockInbound := true + mtu := int64(1280) + + req := &proto.SetConfigRequest{ + ProfileName: profName, + Username: currUser.Username, + ManagementUrl: "https://new-api.netbird.io:443", + AdminURL: "https://new-admin.netbird.io", + RosenpassEnabled: &rosenpassEnabled, + RosenpassPermissive: &rosenpassPermissive, + ServerSSHAllowed: &serverSSHAllowed, + InterfaceName: &interfaceName, + WireguardPort: &wireguardPort, + OptionalPreSharedKey: &preSharedKey, + DisableAutoConnect: &disableAutoConnect, + NetworkMonitor: &networkMonitor, + DisableClientRoutes: &disableClientRoutes, + DisableServerRoutes: &disableServerRoutes, + DisableDns: &disableDNS, + DisableFirewall: &disableFirewall, + BlockLanAccess: &blockLANAccess, + DisableNotifications: &disableNotifications, + LazyConnectionEnabled: &lazyConnectionEnabled, + BlockInbound: &blockInbound, + NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, + CleanNATExternalIPs: false, + CustomDNSAddress: []byte("1.1.1.1:53"), + ExtraIFaceBlacklist: []string{"eth1", "eth2"}, + DnsLabels: []string{"label1", "label2"}, + CleanDNSLabels: false, + DnsRouteInterval: durationpb.New(2 * time.Minute), + Mtu: &mtu, + } + + _, err = s.SetConfig(ctx, req) + require.NoError(t, err) + + profState := profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + } + cfgPath, err := profState.FilePath() + require.NoError(t, err) + + cfg, err := profilemanager.GetConfig(cfgPath) + require.NoError(t, err) + + require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String()) + require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String()) + require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled) + require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive) + require.NotNil(t, cfg.ServerSSHAllowed) + require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed) + require.Equal(t, interfaceName, cfg.WgIface) + require.Equal(t, int(wireguardPort), cfg.WgPort) + require.Equal(t, preSharedKey, cfg.PreSharedKey) + require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect) + require.NotNil(t, cfg.NetworkMonitor) + require.Equal(t, networkMonitor, *cfg.NetworkMonitor) + require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes) + require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes) + require.Equal(t, disableDNS, cfg.DisableDNS) + require.Equal(t, disableFirewall, cfg.DisableFirewall) + require.Equal(t, blockLANAccess, cfg.BlockLANAccess) + require.NotNil(t, cfg.DisableNotifications) + require.Equal(t, disableNotifications, *cfg.DisableNotifications) + require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) + require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) + require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) + // IFaceBlackList contains defaults + extras + require.Contains(t, cfg.IFaceBlackList, "eth1") + require.Contains(t, cfg.IFaceBlackList, "eth2") + require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) + require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) + require.Equal(t, uint16(mtu), cfg.MTU) + + verifyAllFieldsCovered(t, req) +} + +// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest. +// If a new field is added to SetConfigRequest, this function will fail the test, +// forcing the developer to update both the SetConfig handler and this test. +func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { + t.Helper() + + metadataFields := map[string]bool{ + "state": true, // protobuf internal + "sizeCache": true, // protobuf internal + "unknownFields": true, // protobuf internal + "Username": true, // metadata + "ProfileName": true, // metadata + "CleanNATExternalIPs": true, // control flag for clearing + "CleanDNSLabels": true, // control flag for clearing + } + + expectedFields := map[string]bool{ + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + } + + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unexpectedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + if metadataFields[fieldName] { + continue + } + + if !expectedFields[fieldName] { + unexpectedFields = append(unexpectedFields, fieldName) + } + } + + if len(unexpectedFields) > 0 { + t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields) + } +} + +// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest. +// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq. +func TestCLIFlags_MappedToSetConfig(t *testing.T) { + // Map of CLI flag names to their corresponding SetConfigRequest field names. + // This map must be updated when adding new config-related CLI flags. + flagToField := map[string]string{ + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + } + + // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). + fieldsWithoutCLIFlags := map[string]bool{ + "DisableNotifications": true, // Only settable via UI + } + + // Get all SetConfigRequest fields to verify our map is complete. + req := &proto.SetConfigRequest{} + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unmappedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + // Skip protobuf internal fields and metadata fields. + if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" { + continue + } + if fieldName == "Username" || fieldName == "ProfileName" { + continue + } + if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" { + continue + } + + // Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag. + mappedToCLI := false + for _, mappedField := range flagToField { + if mappedField == fieldName { + mappedToCLI = true + break + } + } + + hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName] + + if !mappedToCLI && !hasNoCLIFlag { + unmappedFields = append(unmappedFields, fieldName) + } + } + + if len(unmappedFields) > 0 { + t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+ + "Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+ + "add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields) + } + + t.Log("All SetConfigRequest fields are properly documented") +} From 8252ff41db5fed928a1666e84061ee30d19addf9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:58:29 +0200 Subject: [PATCH 009/118] [client] Add bind activity listener to bypass udp sockets (#4646) --- client/iface/device.go | 1 + client/iface/device/device_android.go | 5 + client/iface/device/device_darwin.go | 5 + client/iface/device/device_ios.go | 5 + client/iface/device/device_kernel_unix.go | 5 + client/iface/device/device_netstack.go | 6 + client/iface/device/device_usp_unix.go | 5 + client/iface/device/device_windows.go | 5 + client/iface/device/endpoint_manager.go | 13 + client/iface/device_android.go | 1 + client/iface/iface.go | 11 + .../internal/lazyconn/activity/lazy_conn.go | 82 +++++ .../lazyconn/activity/listener_bind.go | 127 ++++++++ .../lazyconn/activity/listener_bind_test.go | 291 ++++++++++++++++++ .../lazyconn/activity/listener_test.go | 41 --- .../activity/{listener.go => listener_udp.go} | 35 ++- .../lazyconn/activity/listener_udp_test.go | 110 +++++++ client/internal/lazyconn/activity/manager.go | 45 ++- .../lazyconn/activity/manager_test.go | 38 ++- client/internal/lazyconn/wgiface.go | 2 + 20 files changed, 760 insertions(+), 73 deletions(-) create mode 100644 client/iface/device/endpoint_manager.go create mode 100644 client/internal/lazyconn/activity/lazy_conn.go create mode 100644 client/internal/lazyconn/activity/listener_bind.go create mode 100644 client/internal/lazyconn/activity/listener_bind_test.go delete mode 100644 client/internal/lazyconn/activity/listener_test.go rename client/internal/lazyconn/activity/{listener.go => listener_udp.go} (64%) create mode 100644 client/internal/lazyconn/activity/listener_udp_test.go diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea98..c0c829825 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -23,4 +23,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cc..48346fc0f 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net { return nil } +// GetICEBind returns the ICEBind instance +func (t *WGTunDevice) GetICEBind() EndpointManager { + return t.iceBind +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe088..acd5f6f11 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf..f96edf992 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a53..2a836f846 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns nil for kernel mode devices +func (t *TunKernelDevice) GetICEBind() EndpointManager { + return nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e37321b68..40d8fdac8 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -21,6 +21,7 @@ type Bind interface { conn.Bind GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) ActivityRecorder() *bind.ActivityRecorder + EndpointManager } type TunNetstackDevice struct { @@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device { func (t *TunNetstackDevice) GetNet() *netstack.Net { return t.net } + +// GetICEBind returns the bind instance +func (t *TunNetstackDevice) GetICEBind() EndpointManager { + return t.bind +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a32..24654fc03 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error { func (t *USPDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *USPDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a..96350df8a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/endpoint_manager.go b/client/iface/device/endpoint_manager.go new file mode 100644 index 000000000..b53888baa --- /dev/null +++ b/client/iface/device/endpoint_manager.go @@ -0,0 +1,13 @@ +package device + +import ( + "net" + "net/netip" +) + +// EndpointManager manages fake IP to connection mappings for userspace bind implementations. +// Implemented by bind.ICEBind and bind.RelayBindJS. +type EndpointManager interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97..cdfcea48d 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,4 +21,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/iface.go b/client/iface/iface.go index 158672160..07235a995 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetBind returns the EndpointManager userspace bind mode. +func (w *WGIface) GetBind() device.EndpointManager { + w.mu.Lock() + defer w.mu.Unlock() + + if w.tun == nil { + return nil + } + return w.tun.GetICEBind() +} + // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind diff --git a/client/internal/lazyconn/activity/lazy_conn.go b/client/internal/lazyconn/activity/lazy_conn.go new file mode 100644 index 000000000..2564a9905 --- /dev/null +++ b/client/internal/lazyconn/activity/lazy_conn.go @@ -0,0 +1,82 @@ +package activity + +import ( + "context" + "io" + "net" + "time" +) + +// lazyConn detects activity when WireGuard attempts to send packets. +// It does not deliver packets, only signals that activity occurred. +type lazyConn struct { + activityCh chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// newLazyConn creates a new lazyConn for activity detection. +func newLazyConn() *lazyConn { + ctx, cancel := context.WithCancel(context.Background()) + return &lazyConn{ + activityCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } +} + +// Read blocks until the connection is closed. +func (c *lazyConn) Read(_ []byte) (n int, err error) { + <-c.ctx.Done() + return 0, io.EOF +} + +// Write signals activity detection when ICEBind routes packets to this endpoint. +func (c *lazyConn) Write(b []byte) (n int, err error) { + if c.ctx.Err() != nil { + return 0, io.EOF + } + + select { + case c.activityCh <- struct{}{}: + default: + } + + return len(b), nil +} + +// ActivityChan returns the channel that signals when activity is detected. +func (c *lazyConn) ActivityChan() <-chan struct{} { + return c.activityCh +} + +// Close closes the connection. +func (c *lazyConn) Close() error { + c.cancel() + return nil +} + +// LocalAddr returns the local address. +func (c *lazyConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// RemoteAddr returns the remote address. +func (c *lazyConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// SetDeadline sets the read and write deadlines. +func (c *lazyConn) SetDeadline(_ time.Time) error { + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (c *lazyConn) SetReadDeadline(_ time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *lazyConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go new file mode 100644 index 000000000..792d04215 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -0,0 +1,127 @@ +package activity + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type bindProvider interface { + GetBind() device.EndpointManager +} + +const ( + // lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers. + // The actual routing is done via fakeIP in ICEBind, not by this port. + lazyBindPort = 17473 +) + +// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode. +type BindListener struct { + wgIface WgInterface + peerCfg lazyconn.PeerConfig + done sync.WaitGroup + + lazyConn *lazyConn + bind device.EndpointManager + fakeIP netip.Addr +} + +// NewBindListener creates a listener that passes data directly through bind using LazyConn. +// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range. +func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) { + fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs) + if err != nil { + return nil, fmt.Errorf("derive fake IP: %w", err) + } + + d := &BindListener{ + wgIface: wgIface, + peerCfg: cfg, + bind: bind, + fakeIP: fakeIP, + } + + if err := d.setupLazyConn(); err != nil { + return nil, fmt.Errorf("setup lazy connection: %v", err) + } + + d.done.Add(1) + return d, nil +} + +// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. +// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). +// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { + if len(allowedIPs) == 0 { + return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") + } + + ourNetwork := wgIface.Address().Network + + var peerIP netip.Addr + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if !ip.Is4() { + continue + } + if ourNetwork.Contains(ip) { + peerIP = ip + break + } + } + + if !peerIP.IsValid() { + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + } + + octets := peerIP.As4() + fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) + return fakeIP, nil +} + +func (d *BindListener) setupLazyConn() error { + d.lazyConn = newLazyConn() + d.bind.SetEndpoint(d.fakeIP, d.lazyConn) + + endpoint := &net.UDPAddr{ + IP: d.fakeIP.AsSlice(), + Port: lazyBindPort, + } + return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil) +} + +// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed. +func (d *BindListener) ReadPackets() { + select { + case <-d.lazyConn.ActivityChan(): + d.peerCfg.Log.Infof("activity detected via LazyConn") + case <-d.lazyConn.ctx.Done(): + d.peerCfg.Log.Infof("exit from activity listener") + } + + d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { + d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) + } + + _ = d.lazyConn.Close() + d.bind.RemoveEndpoint(d.fakeIP) + d.done.Done() +} + +// Close stops the listener and cleans up resources. +func (d *BindListener) Close() { + d.peerCfg.Log.Infof("closing activity listener (LazyConn)") + + if err := d.lazyConn.Close(); err != nil { + d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err) + } + + d.done.Wait() +} diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go new file mode 100644 index 000000000..f86dd3877 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -0,0 +1,291 @@ +package activity + +import ( + "net" + "net/netip" + "runtime" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/lazyconn" + peerid "github.com/netbirdio/netbird/client/internal/peer/id" +) + +func isBindListenerPlatform() bool { + return runtime.GOOS == "windows" || runtime.GOOS == "js" +} + +// mockEndpointManager implements device.EndpointManager for testing +type mockEndpointManager struct { + endpoints map[netip.Addr]net.Conn +} + +func newMockEndpointManager() *mockEndpointManager { + return &mockEndpointManager{ + endpoints: make(map[netip.Addr]net.Conn), + } +} + +func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + m.endpoints[fakeIP] = conn +} + +func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) { + delete(m.endpoints, fakeIP) +} + +func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn { + return m.endpoints[fakeIP] +} + +// MockWGIfaceBind mocks WgInterface with bind support +type MockWGIfaceBind struct { + endpointMgr *mockEndpointManager +} + +func (m *MockWGIfaceBind) RemovePeer(string) error { + return nil +} + +func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { + return nil +} + +func (m *MockWGIfaceBind) IsUserspaceBind() bool { + return true +} + +func (m *MockWGIfaceBind) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +func (m *MockWGIfaceBind) GetBind() device.EndpointManager { + return m.endpointMgr +} + +func TestBindListener_Creation(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + expectedFakeIP := netip.MustParseAddr("127.2.0.2") + conn := mockEndpointMgr.GetEndpoint(expectedFakeIP) + require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager") + + _, ok := conn.(*lazyConn) + assert.True(t, ok, "Registered endpoint should be a lazyConn") + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestBindListener_ActivityDetection(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + fakeIP := listener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection") +} + +func TestBindListener_Close(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + fakeIP := listener.fakeIP + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close") +} + +func TestManager_BindMode(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + err := mgr.MonitorPeerActivity(cfg) + require.NoError(t, err) + + listener, exists := mgr.GetPeerListener(cfg.PeerConnID) + require.True(t, exists, "Peer listener should be found") + + bindListener, ok := listener.(*BindListener) + require.True(t, ok, "Listener should be BindListener, got %T", listener) + + fakeIP := bindListener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case peerConnID := <-mgr.OnActivityChan: + assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notification") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity") +} + +func TestManager_BindMode_MultiplePeers(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer1 := &MocPeer{PeerID: "testPeer1"} + peer2 := &MocPeer{PeerID: "testPeer2"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg1 := lazyconn.PeerConfig{ + PublicKey: peer1.PeerID, + PeerConnID: peer1.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + cfg2 := lazyconn.PeerConfig{ + PublicKey: peer2.PeerID, + PeerConnID: peer2.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")}, + Log: log.WithField("peer", "testPeer2"), + } + + err := mgr.MonitorPeerActivity(cfg1) + require.NoError(t, err) + + err = mgr.MonitorPeerActivity(cfg2) + require.NoError(t, err) + + listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID) + require.True(t, exists, "Peer1 listener should be found") + bindListener1 := listener1.(*BindListener) + + listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID) + require.True(t, exists, "Peer2 listener should be found") + bindListener2 := listener2.(*BindListener) + + conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP) + require.NotNil(t, conn1, "Peer1 endpoint should be registered") + _, err = conn1.Write([]byte{0x01}) + require.NoError(t, err) + + conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP) + require.NotNil(t, conn2, "Peer2 endpoint should be registered") + _, err = conn2.Write([]byte{0x02}) + require.NoError(t, err) + + receivedPeers := make(map[peerid.ConnID]bool) + for i := 0; i < 2; i++ { + select { + case peerConnID := <-mgr.OnActivityChan: + receivedPeers[peerConnID] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notifications") + } + } + + assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received") + assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received") +} diff --git a/client/internal/lazyconn/activity/listener_test.go b/client/internal/lazyconn/activity/listener_test.go deleted file mode 100644 index 98d7838d2..000000000 --- a/client/internal/lazyconn/activity/listener_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package activity - -import ( - "testing" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/lazyconn" -) - -func TestNewListener(t *testing.T) { - peer := &MocPeer{ - PeerID: "examplePublicKey1", - } - - cfg := lazyconn.PeerConfig{ - PublicKey: peer.PeerID, - PeerConnID: peer.ConnID(), - Log: log.WithField("peer", "examplePublicKey1"), - } - - l, err := NewListener(MocWGIface{}, cfg) - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - chanClosed := make(chan struct{}) - go func() { - defer close(chanClosed) - l.ReadPackets() - }() - - time.Sleep(1 * time.Second) - l.Close() - - select { - case <-chanClosed: - case <-time.After(time.Second): - } -} diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener_udp.go similarity index 64% rename from client/internal/lazyconn/activity/listener.go rename to client/internal/lazyconn/activity/listener_udp.go index 817ff00c3..e0b09be6c 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener_udp.go @@ -11,26 +11,27 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" ) -// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking -type Listener struct { +// UDPListener uses UDP sockets for activity detection in kernel mode. +type UDPListener struct { wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr done sync.Mutex - isClosed atomic.Bool // use to avoid error log when closing the listener + isClosed atomic.Bool } -func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { - d := &Listener{ +// NewUDPListener creates a listener that detects activity via UDP socket reads. +func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) { + d := &UDPListener{ wgIface: wgIface, peerCfg: cfg, } conn, err := d.newConn() if err != nil { - return nil, fmt.Errorf("failed to creating activity listener: %v", err) + return nil, fmt.Errorf("create UDP connection: %v", err) } d.conn = conn d.endpoint = conn.LocalAddr().(*net.UDPAddr) @@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error if err := d.createEndpoint(); err != nil { return nil, err } + d.done.Lock() - cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) + cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String()) return d, nil } -func (d *Listener) ReadPackets() { +// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. +func (d *UDPListener) ReadPackets() { for { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { @@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() { } d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) - if err := d.removeEndpoint(); err != nil { + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } - _ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" + // Ignore close error as it may return "use of closed network connection" if already closed. + _ = d.conn.Close() d.done.Unlock() } -func (d *Listener) Close() { +// Close stops the listener and cleans up resources. +func (d *UDPListener) Close() { d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) @@ -82,16 +87,12 @@ func (d *Listener) Close() { d.done.Lock() } -func (d *Listener) removeEndpoint() error { - return d.wgIface.RemovePeer(d.peerCfg.PublicKey) -} - -func (d *Listener) createEndpoint() error { +func (d *UDPListener) createEndpoint() error { d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) } -func (d *Listener) newConn() (*net.UDPConn, error) { +func (d *UDPListener) newConn() (*net.UDPConn, error) { addr := &net.UDPAddr{ Port: 0, IP: listenIP, diff --git a/client/internal/lazyconn/activity/listener_udp_test.go b/client/internal/lazyconn/activity/listener_udp_test.go new file mode 100644 index 000000000..d2adb9bf4 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_udp_test.go @@ -0,0 +1,110 @@ +package activity + +import ( + "net" + "net/netip" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func TestUDPListener_Creation(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + require.NotNil(t, listener.conn) + require.NotNil(t, listener.endpoint) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestUDPListener_ActivityDetection(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + conn, err := net.Dial("udp", listener.conn.LocalAddr().String()) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } +} + +func TestUDPListener_Close(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed") +} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 915fb9cb8..db283ec9a 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,21 +1,32 @@ package activity import ( + "errors" "net" "net/netip" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +// listener defines the contract for activity detection listeners. +type listener interface { + ReadPackets() + Close() +} + type WgInterface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + Address() wgaddr.Address } type Manager struct { @@ -23,7 +34,7 @@ type Manager struct { wgIface WgInterface - peers map[peerid.ConnID]*Listener + peers map[peerid.ConnID]listener done chan struct{} mu sync.Mutex @@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, - peers: make(map[peerid.ConnID]*Listener), + peers: make(map[peerid.ConnID]listener), done: make(chan struct{}), } return m @@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error { return nil } - listener, err := NewListener(m.wgIface, peerCfg) + listener, err := m.createListener(peerCfg) if err != nil { return err } - m.peers[peerCfg.PeerConnID] = listener + m.peers[peerCfg.PeerConnID] = listener go m.waitForTraffic(listener, peerCfg.PeerConnID) return nil } +func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) { + if !m.wgIface.IsUserspaceBind() { + return NewUDPListener(m.wgIface, peerCfg) + } + + // BindListener is only used on Windows and JS platforms: + // - JS: Cannot listen to UDP sockets + // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default + // gateway points to, preventing them from reaching the loopback interface. + // BindListener bypasses this by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" { + return NewUDPListener(m.wgIface, peerCfg) + } + + provider, ok := m.wgIface.(bindProvider) + if !ok { + return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") + } + + return NewBindListener(m.wgIface, provider.GetBind(), peerCfg) +} + func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { m.mu.Lock() defer m.mu.Unlock() @@ -82,8 +115,8 @@ func (m *Manager) Close() { } } -func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { - listener.ReadPackets() +func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) { + l.ReadPackets() m.mu.Lock() if _, ok := m.peers[peerConnID]; !ok { diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index ae6c31da4..0768d9219 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) @@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { return nil - } -// Add this method to the Manager struct -func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { +func (m MocWGIface) IsUserspaceBind() bool { + return false +} + +func (m MocWGIface) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +// GetPeerListener is a test helper to access listeners +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { m.mu.Lock() defer m.mu.Unlock() - listener, exists := m.peers[peerConnID] - return listener, exists + l, exists := m.peers[peerConnID] + return l, exists } func TestManager_MonitorPeerActivity(t *testing.T) { @@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("peer listener not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + // Get the UDP listener's address for triggering + udpListener, ok := listener.(*UDPListener) + if !ok { + t.Fatalf("expected UDPListener") + } + if err := trigger(udpListener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() + listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID) + udpListener, _ := listener.(*UDPListener) + addr := udpListener.conn.LocalAddr().String() mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) @@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer1 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener1, _ := listener.(*UDPListener) + if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer2 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener2, _ := listener.(*UDPListener) + if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 0351904f7..0626c1815 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -7,6 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/monotime" ) @@ -14,5 +15,6 @@ type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool + Address() wgaddr.Address LastActivities() map[string]monotime.Time } From 3abae0bd170c95b79378ab1d59ded7fb83ca985e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:16:51 +0200 Subject: [PATCH 010/118] [client] Set default wg port for new profiles (#4651) --- client/internal/profilemanager/config.go | 1 + client/internal/profilemanager/config_test.go | 92 +++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 4e6b422f6..f03822089 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + WgPort: iface.DefaultWgPort, } if _, err := config.apply(input); err != nil { diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 45e37bf0e..90bde7707 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -5,11 +5,14 @@ import ( "errors" "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) @@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) { } } +func TestNewProfileDefaults(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + }) + require.NoError(t, err, "should create new config") + + assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default") + assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default") + assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated") + assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated") + assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default") + assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820") + assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default") + assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default") + assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set") + assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set") + assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults") + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS") + assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS") + } +} + +func TestWireguardPortZeroExplicit(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // Create a new profile with explicit port 0 (random port) + explicitZero := 0 + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: &explicitZero, + }) + require.NoError(t, err, "should create config with explicit port 0") + + assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user") + + // Verify it persists + readConfig, err := GetConfig(configPath) + require.NoError(t, err) + assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file") +} + +func TestWireguardPortDefaultVsExplicit(t *testing.T) { + tests := []struct { + name string + wireguardPort *int + expectedPort int + description string + }{ + { + name: "no port specified uses default", + wireguardPort: nil, + expectedPort: iface.DefaultWgPort, + description: "When user doesn't specify port, default to 51820", + }, + { + name: "explicit zero for random port", + wireguardPort: func() *int { v := 0; return &v }(), + expectedPort: 0, + description: "When user explicitly sets 0, use 0 for random port", + }, + { + name: "explicit custom port", + wireguardPort: func() *int { v := 52000; return &v }(), + expectedPort: 52000, + description: "When user sets custom port, use that port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: tt.wireguardPort, + }) + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedPort, config.WgPort, tt.description) + }) + } +} + func TestUpdateOldManagementURL(t *testing.T) { tests := []struct { name string From af95aabb0375f51872a4ff69dd12f412d716955e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 16 Oct 2025 17:15:39 +0200 Subject: [PATCH 011/118] Handle the case when the service has already been down and the status recorder is not available (#4652) --- client/internal/debug/wgshow.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/internal/debug/wgshow.go b/client/internal/debug/wgshow.go index e4b4c2368..8233ca510 100644 --- a/client/internal/debug/wgshow.go +++ b/client/internal/debug/wgshow.go @@ -14,6 +14,9 @@ type WGIface interface { } func (g *BundleGenerator) addWgShow() error { + if g.statusRecorder == nil { + return fmt.Errorf("no status recorder available for wg show") + } result, err := g.statusRecorder.PeersStatus() if err != nil { return err From 3cdb10cde765621086d15e97da150fee884356f6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:09:39 +0200 Subject: [PATCH 012/118] [client] Remove rule squashing (#4653) --- client/firewall/iptables/acl_linux.go | 1 - client/internal/acl/manager.go | 157 +-------- client/internal/acl/manager_test.go | 486 -------------------------- 3 files changed, 3 insertions(+), 641 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..d78372c9e 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return "" } - // Include action in the ipset name to prevent squashing rules with different actions actionSuffix := "" if action == firewall.ActionDrop { actionSuffix = "-drop" diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5ca950297..965decc73 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -29,11 +29,6 @@ type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } -type protoMatch struct { - ips map[string]int - policyID []byte -} - // DefaultManager uses firewall manager to handle type DefaultManager struct { firewall firewall.Manager @@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout } func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { - rules, squashedProtocols := d.squashAcceptRules(networkMap) + rules := networkMap.FirewallRules enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig.SshEnabled - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - enableSSH = enableSSH && !ok - } - if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { - enableSSH = enableSSH && !ok - } - // if TCP protocol rules not squashed and SSH enabled - // we add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP 22 port). + // If SSH enabled, add default firewall rule which accepts connection to any peer + // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", @@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID( return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } -// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type -// to all peers in the network map to one rule which just accepts that type of the traffic. -// -// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, -// but other has port definitions or has drop policy. -func (d *DefaultManager) squashAcceptRules( - networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { - totalIPs := 0 - for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { - for range p.AllowedIps { - totalIPs++ - } - } - - in := map[mgmProto.RuleProtocol]*protoMatch{} - out := map[mgmProto.RuleProtocol]*protoMatch{} - - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} - - // this function we use to do calculation, can we squash the rules by protocol or not. - // We summ amount of Peers IP for given protocol we found in original rules list. - // But we zeroed the IP's for protocol if: - // 1. Any of the rule has DROP action type. - // 2. Any of rule contains Port. - // - // We zeroed this to notify squash function that this protocol can't be squashed. - addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || - r.Port != "" || !portInfoEmpty(r.PortInfo) - - if hasPortRestrictions { - // Don't squash rules with port restrictions - protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} - return - } - - if _, ok := protocols[r.Protocol]; !ok { - protocols[r.Protocol] = &protoMatch{ - ips: map[string]int{}, - // store the first encountered PolicyID for this protocol - policyID: r.PolicyID, - } - } - - // special case, when we receive this all network IP address - // it means that rules for that protocol was already optimized on the - // management side - if r.PeerIP == "0.0.0.0" { - squashedRules = append(squashedRules, r) - squashedProtocols[r.Protocol] = struct{}{} - return - } - - ipset := protocols[r.Protocol].ips - - if _, ok := ipset[r.PeerIP]; ok { - return - } - ipset[r.PeerIP] = i - } - - for i, r := range networkMap.FirewallRules { - // calculate squash for different directions - if r.Direction == mgmProto.RuleDirection_IN { - addRuleToCalculationMap(i, r, in) - } else { - addRuleToCalculationMap(i, r, out) - } - } - - // order of squashing by protocol is important - // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.RuleProtocol{ - mgmProto.RuleProtocol_ALL, - mgmProto.RuleProtocol_ICMP, - mgmProto.RuleProtocol_TCP, - mgmProto.RuleProtocol_UDP, - } - - squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { - for _, protocol := range protocolOrders { - match, ok := matches[protocol] - if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { - // don't squash if : - // 1. Rules not cover all peers in the network - // 2. Rules cover only one peer in the network. - continue - } - - // add special rule 0.0.0.0 which allows all IP's in our firewall implementations - squashedRules = append(squashedRules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: direction, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: protocol, - PolicyID: match.policyID, - }) - squashedProtocols[protocol] = struct{}{} - - if protocol == mgmProto.RuleProtocol_ALL { - // if we have ALL traffic type squashed rule - // it allows all other type of traffic, so we can stop processing - break - } - } - } - - squash(in, mgmProto.RuleDirection_IN) - squash(out, mgmProto.RuleDirection_OUT) - - // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - return squashedRules, squashedProtocols - } - - if len(squashedRules) == 0 { - return networkMap.FirewallRules, squashedProtocols - } - - var rules []*mgmProto.FirewallRule - // filter out rules which was squashed from final list - // if we also have other not squashed rules. - for i, r := range networkMap.FirewallRules { - if _, ok := squashedProtocols[r.Protocol]; ok { - if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } - } - rules = append(rules, r) - } - - return append(rules, squashedRules...), squashedProtocols -} - // getRuleGroupingSelector takes all rule properties except IP address to build selector func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 664476ef4..daf4979ce 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) { }) } -func TestDefaultManagerSquashRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, 2, len(rules)) - - r := rules[0] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) - - r = rules[1] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) -} - -func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, len(networkMap.FirewallRules), len(rules)) -} - -func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { - tests := []struct { - name string - rules []*mgmProto.FirewallRule - expectedCount int - description string - }{ - { - name: "should not squash rules with port ranges", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with port ranges should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with specific ports", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with specific ports should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with legacy port field", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - }, - expectedCount: 4, - description: "Rules with legacy port field should not be squashed", - }, - { - name: "should not squash rules with DROP action", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "Rules with DROP action should not be squashed", - }, - { - name: "should squash rules without port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 1, - description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule", - }, - { - name: "mixed rules should not squash protocol with port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "TCP should not be squashed because one rule has port restrictions", - }, - { - name: "should squash UDP but not TCP when TCP has port restrictions", - rules: []*mgmProto.FirewallRule{ - // TCP rules with port restrictions - should NOT be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - // UDP rules without port restrictions - SHOULD be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) - description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: tt.rules, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - - assert.Equal(t, tt.expectedCount, len(rules), tt.description) - - // For squashed rules, verify we get the expected 0.0.0.0 rule - if tt.expectedCount == 1 { - assert.Equal(t, "0.0.0.0", rules[0].PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) - } - }) - } -} - func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string From 429d7d65855bec44481e572c83e1429ad0c50ebd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:10:16 +0200 Subject: [PATCH 013/118] [client] Support BROWSER env for login (#4654) --- client/cmd/login.go | 11 ++++++++++- client/ui/client_ui.go | 7 +++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 3ac211805..40b55f858 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/user" "runtime" "strings" @@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := open.Run(verificationURIComplete); err != nil { + if err := openBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } +// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func openBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} + // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7c2000a9d..0043f228e 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -31,7 +31,6 @@ import ( "fyne.io/systray" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { } func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := open.Run(loginResp.VerificationURIComplete) + err := openURL(loginResp.VerificationURIComplete) if err != nil { log.Errorf("opening the verification uri in the browser failed: %v", err) return err @@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } func openURL(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + var err error switch runtime.GOOS { case "windows": From f5301230bfef5d8b3abaf60634646793fa7b63ac Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 13:31:15 +0200 Subject: [PATCH 014/118] [client] Fix status showing P2P without connection (#4661) --- client/status/status.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..5e4fcd8dc 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -205,15 +205,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "P2P" + connType := "-" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if pbPeerState.Relayed { - connType = "Relayed" + if isPeerConnected { + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } } if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { From 0f9bfeff7cf36c81086c8dc91136c9194dbaf819 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 17 Oct 2025 14:47:11 -0300 Subject: [PATCH 015/118] [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 --- client/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ From cd9a867ad02034570daa415fce6fe1dd47c1ccb8 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Fri, 17 Oct 2025 19:48:26 +0200 Subject: [PATCH 016/118] [client] Delete TURNConfig section from script (#4639) --- infrastructure_files/getting-started-with-zitadel.sh | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", From 2fe2af38d29084f9bf226a83c8aecf7440c633bd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:14:39 +0200 Subject: [PATCH 017/118] [client] Clean up match domain reg entries between config changes (#4676) --- client/internal/dns/host_windows.go | 12 ++- client/internal/dns/host_windows_test.go | 102 ++++++++++++++++++ .../internal/winregistry/volatile_windows.go | 59 ++++++++++ 3 files changed, 168 insertions(+), 5 deletions(-) create mode 100644 client/internal/dns/host_windows_test.go create mode 100644 client/internal/winregistry/volatile_windows.go diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..74111d335 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -197,6 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,9 +209,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } @@ -273,9 +275,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} From 96f71ff1e15b50eff87a7052432537dad2718b20 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 21 Oct 2025 19:23:11 +0200 Subject: [PATCH 018/118] [misc] Update tag name extraction in install.sh (#4677) --- release_files/install.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" From d80d47a469ab7d1d740d28ec836d65dc7776beec Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 22 Oct 2025 12:46:22 +0300 Subject: [PATCH 019/118] [management] Add peer disapproval reason (#4468) --- go.mod | 2 +- go.sum | 4 +- management/server/account/manager.go | 2 +- .../http/handlers/peers/peers_handler.go | 38 ++++++++++++------- management/server/integrated_validator.go | 24 +++++++++--- .../integrated_validator/interface.go | 1 + management/server/mock_server/account_mock.go | 6 +-- shared/management/http/api/openapi.yml | 3 ++ shared/management/http/api/types.gen.go | 6 +++ 9 files changed, 61 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index 31b45e881..79dd92e6b 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f + github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 6b0b298a7..f0065e081 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/account/manager.go b/management/server/account/manager.go index a1ed9498b..fe9fb25c6 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -109,7 +109,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 251c04273..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index be05c2527..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -15,6 +15,7 @@ type IntegratedValidator interface { PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d160e7269..e87043f26 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..4a5454002 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -463,6 +463,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..9611d26d6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1037,6 +1037,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1124,6 +1127,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` From 6654e2dbf7c3666a432d7ff457eb3cc5dfdc59aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 23 Oct 2025 17:07:52 +0200 Subject: [PATCH 020/118] [client] Fix active profile name in debug bundle (#4689) --- client/cmd/debug.go | 8 +++++++- client/internal/debug/debug.go | 4 +++- client/ui/debug.go | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..d53c5f06b 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index ec920c5f3..442f54e71 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } From 709e24eb6f7b4beff4da86617bbcf518434b8764 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 24 Oct 2025 15:40:20 +0300 Subject: [PATCH 021/118] [signal] Fix HTTP/WebSocket proxy not using custom certificates (#4644) This pull request fixes a bug where the HTTP/WebSocket proxy server was not using custom TLS certificates when provided via --cert-file and --cert-key flags. Previously, only the gRPC server had TLS enabled with custom certificates, while the HTTP/WebSocket proxy ran without TLS. --- infrastructure_files/configure.sh | 5 ++++- infrastructure_files/docker-compose.yml.tmpl | 8 +++++++ signal/cmd/run.go | 23 ++++++++++---------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { From b9ef214ea54fc524b9685f883e288f65720fbd32 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 27 Oct 2025 18:35:32 +0100 Subject: [PATCH 022/118] [client] Fix macOS state-based dns cleanup (#4701) --- client/internal/dns/host_darwin.go | 41 ++++--- client/internal/dns/host_darwin_test.go | 111 ++++++++++++++++++ client/internal/dns/host_windows.go | 26 ++-- .../internal/dns/unclean_shutdown_darwin.go | 5 + 4 files changed, 151 insertions(+), 32 deletions(-) create mode 100644 client/internal/dns/host_darwin_test.go diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 74111d335..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -179,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -212,13 +206,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -229,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } From eddea145216f822e6b2842d968b95df351f62533 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 27 Oct 2025 18:54:00 +0100 Subject: [PATCH 023/118] [client] Clean up bsd routes independently of the state file (#4688) --- client/internal/routemanager/manager.go | 2 +- .../routemanager/systemops/flush_nonbsd.go | 8 ++ .../internal/routemanager/systemops/state.go | 8 +- .../routemanager/systemops/systemops.go | 2 +- .../systemops/systemops_bsd_test.go | 4 +- .../systemops/systemops_generic_test.go | 6 +- .../routemanager/systemops/systemops_unix.go | 78 ++++++++++++++++++- client/internal/statemanager/manager.go | 2 +- client/server/state.go | 9 +++ 9 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 client/internal/routemanager/systemops/flush_nonbsd.go diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..d590dba0d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -106,7 +106,7 @@ type DefaultManager struct { func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 index 000000000..f1c45d6cf --- /dev/null +++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. +func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..d9b109beb 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + + if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } From 7f08983207a8a39e4610e4b9945cf06b44cb293d Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 27 Oct 2025 22:16:17 +0300 Subject: [PATCH 024/118] Include expired and routing peers in DNS record filtering (#4708) --- management/server/types/account.go | 8 ++++-- management/server/types/account_test.go | 34 ++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index f830023c7..50bdc6ab3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) From 4545ab9a529fd34dc05790866a38141ce9ba7165 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 27 Oct 2025 22:59:35 +0100 Subject: [PATCH 025/118] [management] rewire account manager to permissions manager (#4673) --- go.mod | 2 +- go.sum | 4 ++-- management/internals/server/modules.go | 8 +++++++- .../http/testing/testing_tools/channel/channel.go | 3 ++- management/server/permissions/manager.go | 6 ++++++ management/server/permissions/manager_mock.go | 13 +++++++++++++ 6 files changed, 31 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 79dd92e6b..06dc921f1 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index f0065e081..ce68ed99e 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index daec4ef6f..209a20065 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 741f03f18..bdf56db6e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() From 404cab90ba9c340b848e7f7a457812a20910ed50 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 28 Oct 2025 15:12:53 +0100 Subject: [PATCH 026/118] [client] Redirect dns forwarder port 5353 to new listening port 22054 (#4707) - Port dnat changes from https://github.com/netbirdio/netbird/pull/4015 (nftables/iptables/userspace) - For userspace: rewrite the original port to the target port - Remember original destination port in conntrack - Rewrite the source port back to the original port for replies - Redirect incoming port 5353 to 22054 (tcp/udp) - Revert port changes based on the network map received from management - Adjust tracer to show NAT stages --- client/firewall/iptables/manager_linux.go | 16 + client/firewall/iptables/router_linux.go | 48 +++ client/firewall/manager/firewall.go | 10 +- client/firewall/nftables/manager_linux.go | 16 + client/firewall/nftables/router_linux.go | 97 ++++++ client/firewall/uspfilter/conntrack/common.go | 2 + client/firewall/uspfilter/conntrack/tcp.go | 48 ++- .../firewall/uspfilter/conntrack/tcp_test.go | 8 +- client/firewall/uspfilter/conntrack/udp.go | 34 +- client/firewall/uspfilter/filter.go | 45 ++- client/firewall/uspfilter/log/log.go | 34 +- client/firewall/uspfilter/nat.go | 305 ++++++++++++++---- client/firewall/uspfilter/nat_bench_test.go | 124 +++++++ client/firewall/uspfilter/nat_test.go | 109 +++++++ client/firewall/uspfilter/tracer.go | 233 ++++++++++++- client/firewall/uspfilter/tracer_test.go | 30 ++ client/internal/dnsfwd/manager.go | 69 ++-- client/internal/engine.go | 40 +-- client/internal/netflow/logger/logger.go | 5 +- .../routemanager/dnsinterceptor/handler.go | 4 +- dns/dns.go | 4 + management/server/dns.go | 12 +- management/server/dns_test.go | 24 +- management/server/peer_test.go | 2 +- shared/management/proto/management.proto | 2 +- 25 files changed, 1125 insertions(+), 196 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 81f7a9125..16b50211e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 081991235..80aea7cf8 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + dnatRule := []string{ + "-i", r.wgIface.Name(), + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-m", "addrtype", "--dst-type", "LOCAL", + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + ruleInfo := ruleInfo{ + table: tableNat, + chain: chainRTRDR, + rule: dnatRule, + } + + if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + r.rules[ruleID] = ruleInfo.rule + + r.updateState() + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { + return fmt.Errorf("delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3b3164823..7ee33118b 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -151,14 +151,20 @@ type Manager interface { DisableRouting() error - // AddDNATRule adds a DNAT rule + // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network. AddDNATRule(ForwardRule) (Rule, error) - // DeleteDNATRule deletes a DNAT rule + // DeleteDNATRule deletes the outbound DNAT rule. DeleteDNATRule(Rule) error // UpdateSet updates the set with the given prefixes UpdateSet(hash Set, prefixes []netip.Prefix) error + + // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveInboundDNAT removes inbound DNAT rule + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 560f224f5..aa90d3b9a 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e918d0524..648a6aedf 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 3, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 3, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + RegProtoMax: 0, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingRdr], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if rule, exists := r.rules[ruleID]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state +func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 7eef49e31..fbc39b740 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -109,6 +109,10 @@ type Manager struct { dnatMappings map[netip.Addr]netip.Addr dnatMutex sync.RWMutex dnatBiMap *biDNATMap + + portDNATEnabled atomic.Bool + portDNATRules []portDNATRule + portDNATMutex sync.RWMutex } // decoder for packages @@ -122,6 +126,8 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor @@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), + portDNATRules: []portDNATRule{}, } m.routingEnabled.Store(false) @@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) return false @@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { return false } + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { + // Re-decode after port DNAT translation to update port information + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) + return true + } + srcIP, dstIP = m.extractIPs(d) + } + if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 5614e2ec3..139f702f2 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -50,6 +50,8 @@ type logMessage struct { arg4 any arg5 any arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } - func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { } } +func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { @@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { } } +// Trace8 logs a trace message with 8 arguments (8 placeholder in format string) +func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + default: + } + } +} + func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = (*buf)[:0] *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") @@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { argCount++ if msg.arg6 != nil { argCount++ + if msg.arg7 != nil { + argCount++ + if msg.arg8 != nil { + argCount++ + } + } } } } @@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) case 6: formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) + case 7: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7) + case 8: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8) } *buf = append(*buf, formatted...) @@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error { case <-done: return nil } -} \ No newline at end of file +} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 27b752531..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "net/netip" + "slices" + "github.com/google/gopacket" "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -13,6 +15,21 @@ import ( var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") +var ( + errInvalidIPHeaderLength = errors.New("invalid IP header length") +) + +const ( + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 +) + +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } +// portDNATRule represents a port-specific DNAT rule. +type portDNATRule struct { + protocol gopacket.LayerType + origPort uint16 + targetPort uint16 + targetIP netip.Addr +} + +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap { } } +// set adds a bidirectional DNAT mapping between original and translated addresses. func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original } +// delete removes a bidirectional DNAT mapping for the given original address. func (b *biDNATMap) delete(original netip.Addr) { if translated, exists := b.forward[original]; exists { delete(b.forward, original) @@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) { } } +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } +// getOriginal returns the original address for a given translated address. func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists } +// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation. func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { - if !originalAddr.IsValid() || !translatedAddr.IsValid() { - return fmt.Errorf("invalid IP addresses") + if !originalAddr.IsValid() { + return fmt.Errorf("invalid original IP address") + } + if !translatedAddr.IsValid() { + return fmt.Errorf("invalid translated IP address") } if m.localipmanager.IsLocalIP(translatedAddr) { @@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr return nil } -// RemoveInternalDNATMapping removes a 1:1 IP address mapping +// RemoveInternalDNATMapping removes a 1:1 IP address mapping. func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { m.dnatMutex.Lock() defer m.dnatMutex.Unlock() @@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets +// translateOutboundDNAT applies DNAT translation to outbound packets. func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("Failed to rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic +// translateInboundReverse applies reverse DNAT to inbound return traffic. func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("Failed to rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// rewritePacketSource replaces the source IP address in the packet -func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip return nil } +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 +// incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { return ^uint16(sum) } -// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network. func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errNatNotSupported @@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) return m.nativeFirewall.AddDNATRule(rule) } -// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +// DeleteDNATRule deletes outbound DNAT rule. func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { if m.nativeFirewall == nil { return errNatNotSupported } return m.nativeFirewall.DeleteDNATRule(rule) } + +// addPortRedirection adds a port redirection rule. +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + rule := portDNATRule{ + protocol: protocol, + origPort: sourcePort, + targetPort: targetPort, + targetIP: targetIP, + } + + m.portDNATRules = append(m.portDNATRules, rule) + m.portDNATEnabled.Store(true) + + return nil +} + +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// removePortRedirection removes a port redirection rule. +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) + + if len(m.portDNATRules) == 0 { + m.portDNATEnabled.Store(false) + } + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. +func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { + if !m.portDNATEnabled.Load() { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: + return false + } +} + +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error + +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { + m.portDNATMutex.RLock() + defer m.portDNATMutex.RUnlock() + + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue + } + + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { + return false + } + + if rule.origPort != port { + continue + } + + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) + return false + } + d.dnatOrigPort = rule.origPort + return true + } + return false +} + +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. +func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+4 { + return fmt.Errorf("packet too short for TCP header") + } + + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + if len(packetData) >= tcpStart+18 { + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + + return nil +} + +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. +func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") + } + + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + } + + return nil +} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d726474cf 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", + }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation (5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 710abd445..2a285484c 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/device" ) @@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) { err = manager.RemoveInternalDNATMapping(originalIP) require.Error(t, err, "Should error when removing non-existent mapping") } + +func TestInboundPortDNAT(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) + + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") + + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) + } + + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") + + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + }) + } +} + +func TestInboundPortDNATNegative(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) + require.NoError(t, err) + + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) + + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) + + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } +} diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: "Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: "Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..ee1bb8a23 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a3a4ba40f..a1c0dff98 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -12,18 +14,14 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex -) - const ( - dnsTTL = 60 //seconds + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. @@ -36,28 +34,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + localAddr netip.Addr + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func SetListenPort(port uint16) { - listenPortMu.Lock() - listenPort = port - listenPortMu.Unlock() -} - -func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + localAddr: localAddr, + serverPort: serverPort, } } @@ -71,7 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -96,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -111,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index bebf04f6c..8f75c0646 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,9 +202,6 @@ type Engine struct { // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup - - // dns forwarder port - dnsFwdPort uint16 } // Peer is an instance of the Connection Peer @@ -247,7 +244,6 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), } sm := profilemanager.NewServiceManager("") @@ -1084,7 +1080,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1843,16 +1839,11 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return } - if forwarderPort > 0 { - dnsfwd.SetListenPort(forwarderPort) - } - if !enabled { if e.dnsForwardMgr == nil { return @@ -1864,20 +1855,17 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) + if e.dnsForwardMgr == nil { + localAddr := e.wgInterface.Address().IP + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - default: + log.Infof("started domain router service with %d entries", len(fwdEntries)) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1887,20 +1875,6 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } - -} - -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..a8e697626 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -18,9 +18,9 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + pkgdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), pkgdns.ForwarderClientPort) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..40586f24d 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. + ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" diff --git a/management/server/dns.go b/management/server/dns.go index 534f43ec6..e5166ce47 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -21,8 +21,8 @@ import ( ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort + oldForwarderPort = nbdns.ForwarderClientPort ) const dnsForwarderPortMinVersion = "v0.59.0" @@ -196,7 +196,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { if len(peers) == 0 { - return oldForwarderPort + return int64(oldForwarderPort) } reqVer := semver.Canonical(requiredVersion) @@ -211,17 +211,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) if peerVersion == "" { // If any peer doesn't have version info, return 0 - return oldForwarderPort + return int64(oldForwarderPort) } // Compare versions if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort + return int64(oldForwarderPort) } } // All peers have the required version or newer - return dnsForwarderPort + return int64(dnsForwarderPort) } // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 83caf74ef..96f73a390 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) @@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) } @@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) + result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) { // Test with empty peers list peers := []*nbpeer.Peer{} result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) } @@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) } @@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { + if result != int64(dnsForwarderPort) { t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) } @@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) } @@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) } @@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { + if result == int64(oldForwarderPort) { t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) } @@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index fd795b926..3b2ab87fc 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index ad82d37d9..3982ea2af 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -410,7 +410,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone From d7321c130b56bc831bfdd8cbec5ff211d37e7de4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 28 Oct 2025 16:11:35 +0100 Subject: [PATCH 027/118] [client] The status cmd will not be blocked by the ICE probe (#4597) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The status cmd will not be blocked by the ICE probe Refactor the TURN and STUN probe, and cache the results. The NetBird status command will indicate a "checking…" state. --- client/cmd/debug.go | 4 +- client/cmd/status.go | 6 +- client/internal/engine.go | 20 ++-- client/internal/relay/relay.go | 211 +++++++++++++++++++++++++++++---- client/server/server.go | 9 +- client/status/status.go | 11 +- 6 files changed, 216 insertions(+), 45 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index d53c5f06b..430012a17 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) - stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) + stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) if err != nil { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } @@ -303,7 +303,7 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error { func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string - statusResp, err := getStatus(cmd.Context()) + statusResp, err := getStatus(cmd.Context(), true) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..6e57ceb89 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx) + resp, err := getStatus(ctx, false) if err != nil { return err } @@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ @@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 8f75c0646..48a23f4ad 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,6 +202,8 @@ type Engine struct { // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup + + probeStunTurn *relay.StunTurnProbe } // Peer is an instance of the Connection Peer @@ -244,6 +246,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } sm := profilemanager.NewServiceManager("") @@ -1663,7 +1666,7 @@ func (e *Engine) getRosenpassAddr() string { // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // and updates the status recorder with the latest states. -func (e *Engine) RunHealthProbes() bool { +func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() signalHealthy := e.signal.IsHealthy() @@ -1695,8 +1698,12 @@ func (e *Engine) RunHealthProbes() bool { } e.syncMsgMux.Unlock() - - results := e.probeICE(stuns, turns) + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) + } e.statusRecorder.UpdateRelayStates(results) relayHealthy := true @@ -1713,13 +1720,6 @@ func (e *Engine) RunHealthProbes() bool { return allHealthy } -func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { - return append( - relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), - relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., - ) -} - // restartEngine restarts the engine by cancelling the client context func (e *Engine) restartEngine() { e.syncMsgMux.Lock() diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..693ea1f31 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/sha256" + "errors" "fmt" "net" "sync" @@ -15,6 +17,15 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +const ( + DefaultCacheTTL = 20 * time.Second + probeTimeout = 6 * time.Second +) + +var ( + ErrCheckInProgress = errors.New("probe check is already in progress") +) + // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI string @@ -22,8 +33,164 @@ type ProbeResult struct { Addr string } +type StunTurnProbe struct { + cacheResults []ProbeResult + cacheTimestamp time.Time + cacheKey string + cacheTTL time.Duration + probeInProgress bool + probeDone chan struct{} + mu sync.Mutex +} + +func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe { + return &StunTurnProbe{ + cacheTTL: cacheTTL, + } +} + +func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + if p.probeInProgress { + doneChan := p.probeDone + p.mu.Unlock() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-doneChan: + return p.getCachedResults(cacheKey, stuns, turns) + } + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + p.mu.Unlock() + + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + + return p.getCachedResults(cacheKey, stuns, turns) +} + +// ProbeAll probes all given servers asynchronously and returns the results +func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + + if results := p.checkCache(cacheKey); results != nil { + p.mu.Unlock() + return results + } + + if p.probeInProgress { + p.mu.Unlock() + return createErrorResults(stuns, turns) + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + log.Infof("started new probe for STUN, TURN servers") + go func() { + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + }() + + p.mu.Unlock() + + timer := time.NewTimer(1300 * time.Millisecond) + defer timer.Stop() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-probeDone: + // when the probe is return fast, return the results right away + return p.getCachedResults(cacheKey, stuns, turns) + case <-timer.C: + // if the probe takes longer than 1.3s, return error results to avoid blocking + return createErrorResults(stuns, turns) + } +} + +func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult { + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + age := time.Since(p.cacheTimestamp) + if age < p.cacheTTL { + results := append([]ProbeResult(nil), p.cacheResults...) + log.Debugf("returning cached probe results (age: %v)", age) + return results + } + } + return nil +} + +func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + p.mu.Lock() + defer p.mu.Unlock() + + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + return append([]ProbeResult(nil), p.cacheResults...) + } + return createErrorResults(stuns, turns) +} + +func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) { + defer func() { + p.mu.Lock() + p.probeInProgress = false + p.mu.Unlock() + }() + results := make([]ProbeResult, len(stuns)+len(turns)) + + var wg sync.WaitGroup + for i, uri := range stuns { + wg.Add(1) + go func(idx int, stunURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = stunURI.String() + results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI) + }(i, uri) + } + + stunOffset := len(stuns) + for i, uri := range turns { + wg.Add(1) + go func(idx int, turnURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = turnURI.String() + results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI) + }(stunOffset+i, uri) + } + + wg.Wait() + + p.mu.Lock() + p.cacheResults = results + p.cacheTimestamp = time.Now() + p.cacheKey = cacheKey + p.mu.Unlock() + + log.Debug("Stored new probe results in cache") +} + // ProbeSTUN tries binding to the given STUN uri and acquiring an address -func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("stun probe error from %s: %s", uri, probeErr) @@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } // ProbeTURN tries allocating a session from the given TURN URI -func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("turn probe error from %s: %s", uri, probeErr) @@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) return relayConn.LocalAddr().String(), nil } -// ProbeAll probes all given servers asynchronously and returns the results -func ProbeAll( - ctx context.Context, - fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), - relays []*stun.URI, -) []ProbeResult { - results := make([]ProbeResult, len(relays)) +func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + total := len(stuns) + len(turns) + results := make([]ProbeResult, total) - var wg sync.WaitGroup - for i, uri := range relays { - ctx, cancel := context.WithTimeout(ctx, 6*time.Second) - defer cancel() - - wg.Add(1) - go func(res *ProbeResult, stunURI *stun.URI) { - defer wg.Done() - res.URI = stunURI.String() - res.Addr, res.Err = fn(ctx, stunURI) - }(&results[i], uri) + allURIs := append(append([]*stun.URI{}, stuns...), turns...) + for i, uri := range allURIs { + results[i] = ProbeResult{ + URI: uri.String(), + Err: ErrCheckInProgress, + } } - wg.Wait() - return results } + +func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string { + h := sha256.New() + for _, uri := range stuns { + h.Write([]byte(uri.String())) + } + for _, uri := range turns { + h.Write([]byte(uri.String())) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/client/server/server.go b/client/server/server.go index 89f50a1ef..3641e6f92 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1057,10 +1057,7 @@ func (s *Server) Status( s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) if msg.GetFullPeerStatus { - if msg.ShouldRunProbes { - s.runProbes() - } - + s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() @@ -1070,7 +1067,7 @@ func (s *Server) Status( return &statusResponse, nil } -func (s *Server) runProbes() { +func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return } @@ -1081,7 +1078,7 @@ func (s *Server) runProbes() { } if time.Since(s.lastProbe) > probeThreshold { - if engine.RunHealthProbes() { + if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } } diff --git a/client/status/status.go b/client/status/status.go index 5e4fcd8dc..8a0b7bae0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + probeRelay "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" @@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, for _, relay := range overview.Relays.Details { available := "Available" reason := "" + if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) + if relay.Error == probeRelay.ErrCheckInProgress.Error() { + available = "Checking..." + } else { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { From d3a34adcc948800e7d96b59e2b814348ef5569a0 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:21:40 +0100 Subject: [PATCH 028/118] [client] Fix Connect/Disconnect buttons being enabled or disabled at the same time (#4711) --- client/ui/client_ui.go | 83 +++++++++++++++----------------------- client/ui/event_handler.go | 49 +++++++++++++++++++--- client/ui/profile.go | 18 +++++++-- 3 files changed, 91 insertions(+), 59 deletions(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 0043f228e..22f18b948 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -296,6 +296,8 @@ type serviceClient struct { mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window + + connectCancel context.CancelFunc } type menuHandler struct { @@ -592,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } } -func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { +func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return nil, err + return nil, fmt.Errorf("get daemon client: %w", err) } activeProf, err := s.profileManager.GetActiveProfile() if err != nil { - log.Errorf("get active profile: %v", err) - return nil, err + return nil, fmt.Errorf("get active profile: %w", err) } currUser, err := user.Current() @@ -610,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + loginResp, err := conn.Login(ctx, &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, }) if err != nil { - log.Errorf("login to management URL with: %v", err) - return nil, err + return nil, fmt.Errorf("login to management: %w", err) } if loginResp.NeedsSSOLogin && openURL { - err = s.handleSSOLogin(loginResp, conn) - if err != nil { - log.Errorf("handle SSO login failed: %v", err) - return nil, err + if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil { + return nil, fmt.Errorf("SSO login: %w", err) } } return loginResp, nil } -func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := openURL(loginResp.VerificationURIComplete) - if err != nil { - log.Errorf("opening the verification uri in the browser failed: %v", err) - return err +func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { + if err := openURL(loginResp.VerificationURIComplete); err != nil { + return fmt.Errorf("open browser: %w", err) } - resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) + resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) if err != nil { - log.Errorf("waiting sso login failed with: %v", err) - return err + return fmt.Errorf("wait for SSO login: %w", err) } if resp.Email != "" { - err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ + if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ Email: resp.Email, - }) - if err != nil { - log.Warnf("failed to set profile state: %v", err) + }); err != nil { + log.Debugf("failed to set profile state: %v", err) } else { s.mProfile.refresh() } - } return nil } -func (s *serviceClient) menuUpClick() error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { systray.SetTemplateIcon(iconErrorMacOS, s.icError) - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } - _, err = s.login(true) + _, err = s.login(ctx, true) if err != nil { - log.Errorf("login failed with: %v", err) - return err + return fmt.Errorf("login: %w", err) } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status == string(internal.StatusConnected) { - log.Warnf("already connected") return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - log.Errorf("up service: %v", err) - return err + if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("start connection: %w", err) } return nil @@ -697,24 +684,20 @@ func (s *serviceClient) menuDownClick() error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { - log.Warnf("already down") return nil } - if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("down service: %v", err) - return err + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("stop connection: %w", err) } return nil @@ -1381,7 +1364,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - resp, err := s.login(false) + resp, err := s.login(ctx, false) if err != nil { log.Errorf("failed to fetch login URL: %v", err) return @@ -1401,7 +1384,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) + _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) if err != nil { log.Errorf("Waiting sso login failed with: %v", err) label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") @@ -1409,7 +1392,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } label.SetText("Re-authentication successful.\nReconnecting") - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("get service status: %v", err) return @@ -1422,7 +1405,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.Up(s.ctx, &proto.UpRequest{}) + _, err = conn.Up(ctx, &proto.UpRequest{}) if err != nil { label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") log.Errorf("Reconnecting failed with: %v", err) diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index e9b7f4f30..e0b619411 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" @@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) { func (h *eventHandler) handleConnectClick() { h.client.mUp.Disable() + + if h.client.connectCancel != nil { + h.client.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(h.client.ctx) + h.client.connectCancel = connectCancel + go func() { - defer h.client.mUp.Enable() - if err := h.client.menuUpClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) + defer connectCancel() + + if err := h.client.menuUpClick(connectCtx); err != nil { + st, ok := status.FromError(err) + if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { + log.Debugf("connect operation cancelled by user") + } else { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + log.Errorf("connect failed: %v", err) + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after connect: %v", err) } }() } func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() + + if h.client.connectCancel != nil { + log.Debugf("cancelling ongoing connect operation") + h.client.connectCancel() + h.client.connectCancel = nil + } + go func() { - defer h.client.mDown.Enable() if err := h.client.menuDownClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) + st, ok := status.FromError(err) + if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + log.Errorf("disconnect failed: %v", err) + } else { + log.Debugf("disconnect cancelled or already disconnecting") + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after disconnect: %v", err) } }() } @@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error { } h.client.getSrvConfig() - + return nil } diff --git a/client/ui/profile.go b/client/ui/profile.go index 075223795..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -387,6 +387,7 @@ type subItem struct { type profileMenu struct { mu sync.Mutex ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem @@ -396,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -404,12 +405,13 @@ type profileMenu struct { type newProfileMenuArgs struct { ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -418,6 +420,7 @@ type newProfileMenuArgs struct { func newProfileMenu(args newProfileMenuArgs) *profileMenu { p := profileMenu{ ctx: args.ctx, + serviceClient: args.serviceClient, profileManager: args.profileManager, eventHandler: args.eventHandler, profileMenuItem: args.profileMenuItem, @@ -569,10 +572,19 @@ func (p *profileMenu) refresh() { } } - if err := p.upClickCallback(); err != nil { + if p.serviceClient.connectCancel != nil { + p.serviceClient.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(p.ctx) + p.serviceClient.connectCancel = connectCancel + + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } + connectCancel() + p.refresh() p.loadSettingsCallback() } From 1ee575befe351fefb400d70fe07d4c7c58742d35 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 28 Oct 2025 22:58:43 +0100 Subject: [PATCH 029/118] [client] Use management-provided dns forwarder port on the client side (#4712) --- client/internal/engine.go | 12 +++++++++++- client/internal/routemanager/common/params.go | 2 ++ .../internal/routemanager/dnsinterceptor/handler.go | 6 ++++-- client/internal/routemanager/manager.go | 11 +++++++++++ client/internal/routemanager/mock.go | 4 ++++ dns/dns.go | 2 ++ 6 files changed, 34 insertions(+), 3 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 48a23f4ad..19d37eee1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1059,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + + if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) } + e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) @@ -1207,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) + if forwarderPort == 0 { + forwarderPort = nbdns.ForwarderClientPort + } + dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0), + ForwarderPort: forwarderPort, } for _, zone := range protoDNSConfig.GetCustomZones() { diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go index def18411f..8b5407850 100644 --- a/client/internal/routemanager/common/params.go +++ b/client/internal/routemanager/common/params.go @@ -1,6 +1,7 @@ package common import ( + "sync/atomic" "time" "github.com/netbirdio/netbird/client/firewall/manager" @@ -25,4 +26,5 @@ type HandlerParams struct { UseNewDNSRoute bool Firewall manager.Manager FakeIPManager *fakeip.Manager + ForwarderPort *atomic.Uint32 } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index a8e697626..348338dac 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -8,6 +8,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -20,7 +21,6 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" - pkgdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -55,6 +55,7 @@ type DnsInterceptor struct { peerStore *peerstore.Store firewall firewall.Manager fakeIPManager *fakeip.Manager + forwarderPort *atomic.Uint32 } func New(params common.HandlerParams) *DnsInterceptor { @@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor { firewall: params.Firewall, fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), + forwarderPort: params.ForwarderPort, } } @@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), pkgdns.ForwarderClientPort) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d590dba0d..37974cd17 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -10,6 +10,7 @@ import ( "runtime" "slices" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -23,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" @@ -54,6 +56,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string SetFirewall(firewall.Manager) error + SetDNSForwarderPort(port uint16) Stop(stateManager *statemanager.Manager) } @@ -101,6 +104,7 @@ type DefaultManager struct { disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler fakeIPManager *fakeip.Manager + dnsForwarderPort atomic.Uint32 } func NewManager(config ManagerConfig) *DefaultManager { @@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager { disableServerRoutes: config.DisableServerRoutes, activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), } + dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort)) useNoop := netstack.IsEnabled() || config.DisableClientRoutes dm.setupRefCounters(useNoop) @@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { return nil } +// SetDNSForwarderPort sets the DNS forwarder port for route handlers +func (m *DefaultManager) SetDNSForwarderPort(port uint16) { + m.dnsForwarderPort.Store(uint32(port)) +} + // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() @@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { UseNewDNSRoute: m.useNewDNSRoute, Firewall: m.firewall, FakeIPManager: m.fakeIPManager, + ForwarderPort: &m.dnsForwarderPort, } handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index be633c3fa..6b06144b2 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } +// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface +func (m *MockManager) SetDNSForwarderPort(port uint16) { +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { diff --git a/dns/dns.go b/dns/dns.go index 40586f24d..cf089d4ed 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -35,6 +35,8 @@ type Config struct { NameServerGroups []*NameServerGroup // CustomZones contains a list of custom zone CustomZones []CustomZone + // ForwarderPort is the port clients should connect to on routing peers for DNS forwarding + ForwarderPort uint16 } // CustomZone represents a custom zone to be resolved by the dns server From c530db145564a907eef67937b4b1e33eaa3a3ae7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 29 Oct 2025 17:27:18 +0100 Subject: [PATCH 030/118] [client] Fix UI panic when switching profiles (#4718) --- client/ui/client_ui.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 22f18b948..f4350b251 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -833,6 +833,7 @@ func (s *serviceClient) onTrayReady() { newProfileMenuArgs := &newProfileMenuArgs{ ctx: s.ctx, + serviceClient: s, profileManager: s.profileManager, eventHandler: s.eventHandler, profileMenuItem: profileMenuItem, From 43c9a519131ef69c86c74a8d6a579e4f2c1d236d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:14:27 +0100 Subject: [PATCH 031/118] [client] Migrate deprecated grpc client code (#4687) --- client/grpc/dialer.go | 42 ++++++++++++++++++++++++++------ client/grpc/dialer_generic.go | 3 +-- client/net/conn.go | 12 +++++---- client/net/dial.go | 1 - client/net/dialer_dial.go | 3 ++- client/net/listener_listen.go | 4 +-- shared/management/client/grpc.go | 3 +-- shared/signal/client/grpc.go | 3 +-- 8 files changed, 49 insertions(+), 22 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 6aff53b92..7763f2417 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" "runtime" "time" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -17,6 +20,9 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) +// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready +var ErrConnectionShutdown = errors.New("connection shutdown before ready") + // Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// waitForConnectionReady blocks until the connection becomes ready or fails. +// Returns an error if the connection times out, is cancelled, or enters shutdown state. +func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { + conn.Connect() + + state := conn.GetState() + for state != connectivity.Ready && state != connectivity.Shutdown { + if !conn.WaitForStateChange(ctx, state) { + return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) + } + state = conn.GetState() + } + + if state == connectivity.Shutdown { + return ErrConnectionShutdown + } + + return nil +} + // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - conn, err := grpc.DialContext( - connCtx, + conn, err := grpc.NewClient( addr, transportOption, WithCustomDialer(tlsEnabled, component), - grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - log.Printf("DialContext error: %v", err) + return nil, fmt.Errorf("new client: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForConnectionReady(ctx, conn); err != nil { + _ = conn.Close() return nil, err } diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index 96f347c64..479575996 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { +func WithCustomDialer(_ bool, _ string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { - log.Errorf("Failed to dial: %s", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil diff --git a/client/net/conn.go b/client/net/conn.go index 918e7f628..bf54c792d 100644 --- a/client/net/conn.go +++ b/client/net/conn.go @@ -17,8 +17,7 @@ type Conn struct { ID hooks.ConnectionID } -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection. func (c *Conn) Close() error { return closeConn(c.ID, c.Conn) } @@ -29,7 +28,7 @@ type TCPConn struct { ID hooks.ConnectionID } -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. func (c *TCPConn) Close() error { return closeConn(c.ID, c.TCPConn) } @@ -37,13 +36,16 @@ func (c *TCPConn) Close() error { // closeConn is a helper function to close connections and execute close hooks. func closeConn(id hooks.ConnectionID, conn io.Closer) error { err := conn.Close() + cleanupConnID(id) + return err +} +// cleanupConnID executes close hooks for a connection ID. +func cleanupConnID(id hooks.ConnectionID) { closeHooks := hooks.GetCloseHooks() for _, hook := range closeHooks { if err := hook(id); err != nil { log.Errorf("Error executing close hook: %v", err) } } - - return err } diff --git a/client/net/dial.go b/client/net/dial.go index 041a00e5d..17c9ff98a 100644 --- a/client/net/dial.go +++ b/client/net/dial.go @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro } return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil } - if err := conn.Close(); err != nil { log.Errorf("failed to close connection: %v", err) } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go index 2e1eb53d8..1e275013f 100644 --- a/client/net/dialer_dial.go +++ b/client/net/dialer_dial.go @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { + cleanupConnID(connID) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str ips, err := resolver.LookupIPAddr(ctx, host) if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) + return fmt.Errorf("resolve address %s: %w", address, err) } log.Debugf("Dialer resolved IPs for %s: %v", address, ips) diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go index 0bb5ad67d..a150172b4 100644 --- a/client/net/listener_listen.go +++ b/client/net/listener_listen.go @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection. func (c *PacketConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.PacketConn) @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.UDPConn.WriteTo(b, addr) } -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection. func (c *UDPConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.UDPConn) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 076f2532b..520a83e36 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 31f3372c0..5368b57a2 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } From 86eff0d75047cd17a450778439f0112e4817a35e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:18:09 +0100 Subject: [PATCH 032/118] [client] Fix netstack dns forwarder (#4727) --- client/firewall/uspfilter/filter.go | 97 +++++++++++++++++++++++- client/internal/dnsfwd/cache_test.go | 1 - client/internal/dnsfwd/forwarder.go | 66 ++++++++++++---- client/internal/dnsfwd/forwarder_test.go | 20 ++--- client/internal/dnsfwd/manager.go | 37 ++++++--- client/internal/engine.go | 70 ++++++++++++----- 6 files changed, 231 insertions(+), 60 deletions(-) diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index fbc39b740..ec2d2c57f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -50,6 +50,12 @@ const ( var errNatNotSupported = errors.New("nat not supported with userspace firewall") +// serviceKey represents a protocol/port combination for netstack service registry +type serviceKey struct { + protocol gopacket.LayerType + port uint16 +} + // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule @@ -113,6 +119,9 @@ type Manager struct { portDNATEnabled atomic.Bool portDNATRules []portDNATRule portDNATMutex sync.RWMutex + + netstackServices map[serviceKey]struct{} + netstackServiceMutex sync.RWMutex } // decoder for packages @@ -203,6 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), portDNATRules: []portDNATRule{}, + netstackServices: make(map[serviceKey]struct{}), } m.routingEnabled.Store(false) @@ -838,9 +848,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet return true } - // If requested we pass local traffic to internal interfaces to the forwarder. - // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder. - if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) { + if m.shouldForward(d, dstIP) { return m.handleForwardedLocalTraffic(packetData) } @@ -1274,3 +1282,86 @@ func (m *Manager) DisableRouting() error { return nil } + +// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port +func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + m.netstackServices[key] = struct{}{} + m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType) + m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices)) +} + +// UnregisterNetstackService removes a service from the netstack registry +func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + delete(m.netstackServices, key) + m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port) +} + +// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use +func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType { + switch protocol { + case nftypes.TCP: + return layers.LayerTypeTCP + case nftypes.UDP: + return layers.LayerTypeUDP + case nftypes.ICMP: + return layers.LayerTypeICMPv4 + default: + return gopacket.LayerType(0) // Invalid/unknown + } +} + +// shouldForward determines if a packet should be forwarded to the forwarder. +// The forwarder handles routing packets to the native OS network stack. +// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly. +func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { + // not enabled, never forward + if !m.localForwarding { + return false + } + + // netstack always needs to forward because it's lacking a native interface + // exception for registered netstack services, those should go to netstack listeners + if m.netstack { + return !m.hasMatchingNetstackService(d) + } + + // traffic to our other local interfaces (not NetBird IP) - always forward + if dstIP != m.wgIface.Address().IP { + return true + } + + // traffic to our NetBird IP, not netstack mode - send to netstack listeners + return false +} + +// hasMatchingNetstackService checks if there's a registered netstack service for this packet +func (m *Manager) hasMatchingNetstackService(d *decoder) bool { + if len(d.decoded) < 2 { + return false + } + + var dstPort uint16 + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + dstPort = uint16(d.udp.DstPort) + default: + return false + } + + key := serviceKey{protocol: d.decoded[1], port: dstPort} + m.netstackServiceMutex.RLock() + _, exists := m.netstackServices[key] + m.netstackServiceMutex.RUnlock() + + return exists +} diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go index c23f0f31d..44ebe290b 100644 --- a/client/internal/dnsfwd/cache_test.go +++ b/client/internal/dnsfwd/cache_test.go @@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) { t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) } } - diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 7a262fa4c..aef16a8cf 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -33,7 +34,7 @@ type firewaller interface { } type DNSForwarder struct { - listenAddress string + listenAddress netip.AddrPort ttl uint32 statusRecorder *peer.Status @@ -47,9 +48,11 @@ type DNSForwarder struct { firewall firewaller resolver resolver cache *cache + + wgIface wgIface } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, @@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat statusRecorder: statusRecorder, resolver: net.DefaultResolver, cache: newCache(), + wgIface: wgIface, } } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("starting DNS forwarder on address=%s", f.listenAddress) + var netstackNet *netstack.Net + if f.wgIface != nil { + netstackNet = f.wgIface.GetNet() + } + + addrDesc := f.listenAddress.String() + if netstackNet != nil { + addrDesc = fmt.Sprintf("netstack %s", f.listenAddress) + } + log.Infof("starting DNS forwarder on address=%s", addrDesc) + + udpLn, err := f.createUDPListener(netstackNet) + if err != nil { + return fmt.Errorf("create UDP listener: %w", err) + } + + tcpLn, err := f.createTCPListener(netstackNet) + if err != nil { + return fmt.Errorf("create TCP listener: %w", err) + } - // UDP server mux := dns.NewServeMux() f.mux = mux mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ - Addr: f.listenAddress, - Net: "udp", - Handler: mux, + PacketConn: udpLn, + Handler: mux, } - // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ - Addr: f.listenAddress, - Net: "tcp", - Handler: tcpMux, + Listener: tcpLn, + Handler: tcpMux, } f.UpdateDomains(entries) @@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { errCh := make(chan error, 2) go func() { - log.Infof("DNS UDP listener running on %s", f.listenAddress) - errCh <- f.dnsServer.ListenAndServe() + log.Infof("DNS UDP listener running on %s", addrDesc) + errCh <- f.dnsServer.ActivateAndServe() }() go func() { - log.Infof("DNS TCP listener running on %s", f.listenAddress) - errCh <- f.tcpServer.ListenAndServe() + log.Infof("DNS TCP listener running on %s", addrDesc) + errCh <- f.tcpServer.ActivateAndServe() }() - // return the first error we get (e.g. bind failure or shutdown) return <-errCh } +func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) { + if netstackNet != nil { + return netstackNet.ListenUDPAddrPort(f.listenAddress) + } + + return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress)) +} + +func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) { + if netstackNet != nil { + return netstackNet.ListenTCPAddrPort(f.listenAddress) + } + + return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress)) +} + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index c1c95a2c1..4d0b96a75 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) } - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString(tt.configuredDomain) @@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { mockResolver := &MockResolver{} // Set up forwarder - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Create entries and track sets @@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Configure a single domain @@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) d, err := domain.FromString(tt.configured) require.NoError(t, err) @@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) { // Test that large UDP responses are truncated with TC bit set mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, _ := domain.FromString("example.com") @@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { // a subsequent upstream failure still returns a successful response from cache. func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Verifies that cache normalization works across casing and trailing dot variations. func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("ExAmPlE.CoM") @@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Set up complex overlapping patterns @@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) { // Test handling of malformed query with no questions - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) query := &dns.Msg{} // Don't set any question diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a1c0dff98..b26836d17 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -10,9 +10,11 @@ import ( "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -24,6 +26,12 @@ const ( envServerPort = "NB_DNS_FORWARDER_PORT" ) +// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder. +type wgIface interface { + GetNet() *netstack.Net + Address() wgaddr.Address +} + // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. type ForwarderEntry struct { Domain domain.Domain @@ -34,7 +42,7 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status - localAddr netip.Addr + wgIface wgIface serverPort uint16 fwRules []firewall.Rule @@ -42,7 +50,7 @@ type Manager struct { dnsForwarder *DNSForwarder } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager { serverPort := nbdns.ForwarderServerPort if envPort := os.Getenv(envServerPort); envPort != "" { if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { @@ -56,7 +64,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr neti return &Manager{ firewall: fw, statusRecorder: statusRecorder, - localAddr: localAddr, + wgIface: wgIface, serverPort: serverPort, } } @@ -71,21 +79,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.localAddr.IsValid() && m.firewall != nil { - if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + localAddr := m.wgIface.Address().IP + + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { log.Warnf("failed to add DNS UDP DNAT rule: %v", err) } else { - log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) } - if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { log.Warnf("failed to add DNS TCP DNAT rule: %v", err) } else { - log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) } } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) + listenAddress := netip.AddrPortFrom(localAddr, m.serverPort) + m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface) + go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -111,12 +123,13 @@ func (m *Manager) Stop(ctx context.Context) error { var mErr *multierror.Error - if m.localAddr.IsValid() && m.firewall != nil { - if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + localAddr := m.wgIface.Address().IP + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) } - if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) } } diff --git a/client/internal/engine.go b/client/internal/engine.go index 19d37eee1..ad69bcf43 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1855,35 +1855,69 @@ func (e *Engine) updateDNSForwarder( } if !enabled { - if e.dnsForwardMgr == nil { - return - } - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() return } if len(fwdEntries) > 0 { if e.dnsForwardMgr == nil { - localAddr := e.wgInterface.Address().IP - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) - - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } - - log.Infof("started domain router service with %d entries", len(fwdEntries)) + e.startDNSForwarder(fwdEntries) } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() + } +} + +func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) + e.registerDNSServices() + + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { + log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil + return + } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) +} + +func (e *Engine) stopDNSForwarder() { + if e.dnsForwardMgr == nil { + return + } + + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + + e.unregisterDNSServices() + e.dnsForwardMgr = nil +} + +func (e *Engine) registerDNSServices() { + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) + registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) + } + } +} + +func (e *Engine) unregisterDNSServices() { + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) + registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) + } } } From 8c108ccad34d1c6aba6a009bd1820fe4ff6d71b5 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 31 Oct 2025 19:19:14 +0100 Subject: [PATCH 033/118] [client] Extend Darwin network monitoring with wakeup detection --- .../networkmonitor/check_change_bsd.go | 77 +-------- .../networkmonitor/check_change_common.go | 92 +++++++++++ .../networkmonitor/check_change_darwin.go | 149 ++++++++++++++++++ client/internal/networkmonitor/monitor.go | 1 + 4 files changed, 246 insertions(+), 73 deletions(-) create mode 100644 client/internal/networkmonitor/check_change_common.go create mode 100644 client/internal/networkmonitor/check_change_darwin.go diff --git a/client/internal/networkmonitor/check_change_bsd.go b/client/internal/networkmonitor/check_change_bsd.go index f5eb2c739..b3482f54e 100644 --- a/client/internal/networkmonitor/check_change_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -1,4 +1,4 @@ -//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || openbsd package networkmonitor @@ -6,21 +6,19 @@ import ( "context" "errors" "fmt" - "syscall" - "unsafe" log "github.com/sirupsen/logrus" - "golang.org/x/net/route" "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + fd, err := prepareFd() if err != nil { return fmt.Errorf("open routing socket: %v", err) } + defer func() { err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { @@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er } }() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - buf := make([]byte, 2048) - n, err := unix.Read(fd, buf) - if err != nil { - if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Warnf("Network monitor: failed to read from routing socket: %v", err) - } - continue - } - if n < unix.SizeofRtMsghdr { - log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) - continue - } - - msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - - switch msg.Type { - // handle route changes - case unix.RTM_ADD, syscall.RTM_DELETE: - route, err := parseRouteMessage(buf[:n]) - if err != nil { - log.Debugf("Network monitor: error parsing routing message: %v", err) - continue - } - - if route.Dst.Bits() != 0 { - continue - } - - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - } - switch msg.Type { - case unix.RTM_ADD: - log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - return nil - case unix.RTM_DELETE: - if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - return nil - } - } - } - } - } -} - -func parseRouteMessage(buf []byte) (*systemops.Route, error) { - msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) - if err != nil { - return nil, fmt.Errorf("parse RIB: %v", err) - } - - if len(msgs) != 1 { - return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) - } - - msg, ok := msgs[0].(*route.RouteMessage) - if !ok { - return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) - } - - return systemops.MsgToRoute(msg) + return routeCheck(ctx, fd, nexthopv4, nexthopv6) } diff --git a/client/internal/networkmonitor/check_change_common.go b/client/internal/networkmonitor/check_change_common.go new file mode 100644 index 000000000..c287236e8 --- /dev/null +++ b/client/internal/networkmonitor/check_change_common.go @@ -0,0 +1,92 @@ +//go:build dragonfly || freebsd || netbsd || openbsd || darwin + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func prepareFd() (int, error) { + return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) +} + +func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + buf := make([]byte, 2048) + n, err := unix.Read(fd, buf) + if err != nil { + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Warnf("Network monitor: failed to read from routing socket: %v", err) + } + continue + } + if n < unix.SizeofRtMsghdr { + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + continue + } + + msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + + switch msg.Type { + // handle route changes + case unix.RTM_ADD, syscall.RTM_DELETE: + route, err := parseRouteMessage(buf[:n]) + if err != nil { + log.Debugf("Network monitor: error parsing routing message: %v", err) + continue + } + + if route.Dst.Bits() != 0 { + continue + } + + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + } + switch msg.Type { + case unix.RTM_ADD: + log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + return nil + case unix.RTM_DELETE: + if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) + return nil + } + } + } + } + } +} + +func parseRouteMessage(buf []byte) (*systemops.Route, error) { + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + if len(msgs) != 1 { + return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + } + + msg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + } + + return systemops.MsgToRoute(msg) +} diff --git a/client/internal/networkmonitor/check_change_darwin.go b/client/internal/networkmonitor/check_change_darwin.go new file mode 100644 index 000000000..ddc6e1736 --- /dev/null +++ b/client/internal/networkmonitor/check_change_darwin.go @@ -0,0 +1,149 @@ +//go:build darwin && !ios + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "hash/fnv" + "os/exec" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// todo: refactor to not use static functions + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + fd, err := prepareFd() + if err != nil { + return fmt.Errorf("open routing socket: %v", err) + } + + defer func() { + if err := unix.Close(fd); err != nil { + if !errors.Is(err, unix.EBADF) { + log.Warnf("Network monitor: failed to close routing socket: %v", err) + } + } + }() + + routeChanged := make(chan struct{}) + go func() { + _ = routeCheck(ctx, fd, nexthopv4, nexthopv6) + close(routeChanged) + }() + + wakeUp := make(chan struct{}) + go func() { + wakeUpListen(ctx) + close(wakeUp) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-routeChanged: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("route change detected") + return nil + case <-wakeUp: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("wakeup detected") + return nil + } +} + +func wakeUpListen(ctx context.Context) { + log.Infof("start to watch for system wakeups") + var ( + initialHash uint32 + err error + ) + + // Keep retrying until initial sysctl succeeds or context is canceled + for { + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + default: + initialHash, err = readSleepTimeHash() + if err != nil { + log.Errorf("failed to detect initial sleep time: %v", err) + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + case <-time.After(3 * time.Second): + continue + } + } + log.Debugf("initial wakeup hash: %d", initialHash) + break + } + break + } + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info("context canceled, stopping wakeUpListen") + return + + case <-ticker.C: + newHash, err := readSleepTimeHash() + if err != nil { + log.Errorf("failed to read sleep time hash: %v", err) + continue + } + + if newHash == initialHash { + log.Tracef("no wakeup detected") + continue + } + + upOut, err := exec.Command("uptime").Output() + if err != nil { + log.Errorf("failed to run uptime command: %v", err) + upOut = []byte("unknown") + } + log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut) + return + } + } +} + +func readSleepTimeHash() (uint32, error) { + cmd := exec.Command("sysctl", "kern.sleeptime") + out, err := cmd.Output() + if err != nil { + return 0, fmt.Errorf("failed to run sysctl: %w", err) + } + + h, err := hash(out) + if err != nil { + return 0, fmt.Errorf("failed to compute hash: %w", err) + } + + return h, nil +} + +func hash(data []byte) (uint32, error) { + hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher + if _, err := hasher.Write(data); err != nil { + return 0, err + } + return hasher.Sum32(), nil +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index accdd9c9d..6d019258d 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { event := make(chan struct{}, 1) go nw.checkChanges(ctx, event, nexthop4, nexthop6) + log.Infof("start watching for network changes") // debounce changes timer := time.NewTimer(0) timer.Stop() From a2313a5ba4959a5550c242947a4faf3f9b0cde9b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 15:27:22 +0100 Subject: [PATCH 034/118] [client] Bump github.com/quic-go/quic-go from 0.48.2 to 0.49.1 (#4621) Bumps [github.com/quic-go/quic-go](https://github.com/quic-go/quic-go) from 0.48.2 to 0.49.1. - [Release notes](https://github.com/quic-go/quic-go/releases) - [Commits](https://github.com/quic-go/quic-go/compare/v0.48.2...v0.49.1) --- updated-dependencies: - dependency-name: github.com/quic-go/quic-go dependency-version: 0.49.1 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 06dc921f1..98b158411 100644 --- a/go.mod +++ b/go.mod @@ -76,7 +76,7 @@ require ( github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 - github.com/quic-go/quic-go v0.48.2 + github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 @@ -241,7 +241,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.4.0 // indirect + go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect diff --git a/go.sum b/go.sum index ce68ed99e..9705f3aed 100644 --- a/go.sum +++ b/go.sum @@ -590,8 +590,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= -github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= +github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -749,8 +749,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= From 719283c792a46311daa334254c5f90a39e0afa00 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:40:12 +0100 Subject: [PATCH 035/118] [management] update db connection lifecycle configuration (#4740) --- management/server/activity/store/sql_store.go | 23 ++++++++++++------- management/server/store/sql_store.go | 6 ++++- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index 80b165938..ffecb6b8f 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -7,6 +7,7 @@ import ( "path/filepath" "runtime" "strconv" + "time" log "github.com/sirupsen/logrus" "gorm.io/driver/postgres" @@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e return nil, err } - if storeEngine == types.SqliteStoreEngine { - sqlDB.SetMaxOpenConns(1) - } else { - conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) - if err != nil { - conns = runtime.NumCPU() - } - sqlDB.SetMaxOpenConns(conns) + conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) + if err != nil { + conns = runtime.NumCPU() } + if storeEngine == types.SqliteStoreEngine { + conns = 1 + } + + sqlDB.SetMaxOpenConns(conns) + sqlDB.SetMaxIdleConns(conns) + sqlDB.SetConnMaxLifetime(time.Hour) + sqlDB.SetConnMaxIdleTime(3 * time.Minute) + + log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) return db, nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 382d026c8..4201b68f6 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -89,8 +89,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met } sql.SetMaxOpenConns(conns) + sql.SetMaxIdleConns(conns) + sql.SetConnMaxLifetime(time.Hour) + sql.SetConnMaxIdleTime(3 * time.Minute) - log.WithContext(ctx).Infof("Set max open db connections to %d", conns) + log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) if skipMigration { log.WithContext(ctx).Infof("skipping migration") From 679c58ce472d8763d8bea987662104fa6a381dc2 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 4 Nov 2025 17:06:35 +0100 Subject: [PATCH 036/118] [client] Set up networkd to ignore ip rules (#4730) --- client/cmd/service_installer.go | 54 +++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 075ead44e..2a87e538d 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -10,6 +10,8 @@ import ( "path/filepath" "runtime" + log "github.com/sirupsen/logrus" + "github.com/kardianos/service" "github.com/spf13/cobra" @@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error { svcConfig.Option["LogDirectory"] = dir } } + + if err := configureSystemdNetworkd(); err != nil { + log.Warnf("failed to configure systemd-networkd: %v", err) + } } if runtime.GOOS == "windows" { @@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{ return fmt.Errorf("uninstall service: %w", err) } + if runtime.GOOS == "linux" { + if err := cleanupSystemdNetworkd(); err != nil { + log.Warnf("failed to cleanup systemd-networkd configuration: %v", err) + } + } + cmd.Println("NetBird service has been uninstalled") return nil }, @@ -245,3 +257,45 @@ func isServiceRunning() (bool, error) { return status == service.StatusRunning, nil } + +const ( + networkdConfDir = "/etc/systemd/networkd.conf.d" + networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" + networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing +# routes and policy rules managed by NetBird. + +[Network] +ManageForeignRoutes=no +ManageForeignRoutingPolicyRules=no +` +) + +// configureSystemdNetworkd creates a drop-in configuration file to prevent +// systemd-networkd from removing NetBird's routes and policy rules. +func configureSystemdNetworkd() error { + parentDir := filepath.Dir(networkdConfDir) + if _, err := os.Stat(parentDir); os.IsNotExist(err) { + log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration") + return nil + } + + // nolint:gosec // standard networkd permissions + if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { + return fmt.Errorf("write networkd configuration: %w", err) + } + + return nil +} + +// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file. +func cleanupSystemdNetworkd() error { + if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) { + return nil + } + + if err := os.Remove(networkdConfFile); err != nil { + return fmt.Errorf("remove networkd configuration: %w", err) + } + + return nil +} From 45c25dca84150c0da85b06875facb01f0bcd119a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 4 Nov 2025 17:18:51 +0100 Subject: [PATCH 037/118] [client] Clamp MSS on outbound traffic (#4735) --- client/firewall/create.go | 4 +- client/firewall/create_linux.go | 22 +- client/firewall/iptables/manager_linux.go | 5 +- .../firewall/iptables/manager_linux_test.go | 9 +- client/firewall/iptables/router_linux.go | 50 +++- client/firewall/iptables/router_linux_test.go | 14 +- client/firewall/iptables/state_linux.go | 8 +- client/firewall/nftables/manager_linux.go | 5 +- .../firewall/nftables/manager_linux_test.go | 9 +- client/firewall/nftables/router_linux.go | 98 ++++++- client/firewall/nftables/router_linux_test.go | 9 +- client/firewall/nftables/state_linux.go | 8 +- client/firewall/uspfilter/filter.go | 154 ++++++++-- .../firewall/uspfilter/filter_bench_test.go | 277 +++++++++++++++--- .../firewall/uspfilter/filter_filter_test.go | 7 +- client/firewall/uspfilter/filter_test.go | 215 +++++++++++++- .../firewall/uspfilter/forwarder/forwarder.go | 8 +- client/firewall/uspfilter/forwarder/udp.go | 2 +- client/firewall/uspfilter/nat_bench_test.go | 11 +- client/firewall/uspfilter/nat_test.go | 9 +- client/firewall/uspfilter/tracer_test.go | 3 +- client/internal/acl/manager_test.go | 7 +- client/internal/dns/server_test.go | 2 +- client/internal/engine.go | 2 +- 24 files changed, 804 insertions(+), 134 deletions(-) diff --git a/client/firewall/create.go b/client/firewall/create.go index 7b265e1d1..24f12bc6d 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -15,13 +15,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index aa2f0d4d1..12dcaee8a 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) if !iface.IsUserspaceBind() { return fm, err @@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) + return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { - fm, err := createFW(iface) +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { + fm, err := createFW(iface, mtu) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) } @@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, return fm, nil } -func createFW(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) { switch check() { case IPTABLES: log.Info("creating an iptables firewall manager") - return nbiptables.Create(iface) + return nbiptables.Create(iface, mtu) case NFTABLES: log.Info("creating an nftables firewall manager") - return nbnftables.Create(iface) + return nbnftables.Create(iface, mtu) default: log.Info("no firewall manager found, trying to use userspace packet filtering firewall") return nil, errors.New("no firewall manager found") } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } else { - fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) } if errUsp != nil { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 16b50211e..2563a9052 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -36,7 +36,7 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("init iptables: %w", err) @@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, } stateManager.RegisterState(state) diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index a5cc62feb..6b5401e2b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 80aea7cf8..305b0bf28 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -30,17 +30,20 @@ const ( chainPOSTROUTING = "POSTROUTING" chainPREROUTING = "PREROUTING" + chainFORWARD = "FORWARD" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" matchSet = "--match-set" @@ -48,6 +51,9 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" fwdSuffix = "_fwd" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) type ruleInfo struct { @@ -77,16 +83,18 @@ type router struct { ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + mtu uint16 stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState } -func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, + mtu: mtu, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { @@ -416,6 +425,7 @@ func (r *router) createContainers() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) @@ -438,6 +448,10 @@ func (r *router) createContainers() error { return fmt.Errorf("add jump rules: %w", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + return nil } @@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + // Add jump rule from FORWARD chain in mangle table to our custom chain + jumpRule := []string{ + "-j", chainRTMSSCLAMP, + } + if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil { + return fmt.Errorf("add jump to MSS clamp chain: %w", err) + } + r.rules[jumpMSSClamp] = jumpRule + + ruleOut := []string{ + "-o", r.wgIface.Name(), + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mss), + } + if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil { + return fmt.Errorf("add outbound MSS clamp rule: %w", err) + } + r.rules["mss-clamp-out"] = ruleOut + + return nil +} + func (r *router) insertEstablishedRule(chain string) error { establishedRule := getConntrackEstablished() @@ -558,7 +601,7 @@ func (r *router) addJumpRules() error { } func (r *router) cleanJumpRules() error { - for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { + for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} { if rule, exists := r.rules[ruleKey]; exists { var table, chain string switch ruleKey { @@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error { case jumpNatPre: table = tableNat chain = chainPREROUTING + case jumpMSSClamp: + table = tableMangle + chain = chainFORWARD default: return fmt.Errorf("unknown jump rule: %s", ruleKey) } diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 3490c5dad..6707573be 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, manager.init(nil)) @@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { assert.NoError(t, manager.Reset(), "shouldn't return error") }() - // Now 5 rules: // 1. established rule forward in // 2. estbalished rule forward out // 3. jump rule to POST nat chain @@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { // 7. static return masquerade rule // 8. mangle prerouting mark rule // 9. mangle postrouting mark rule - require.Len(t, manager.rules, 9, "should have created rules map") + // 10. jump rule to MSS clamping chain + // 11. MSS clamping rule for outbound traffic + require.Len(t, manager.rules, 11, "should have created rules map") exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) @@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) @@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) defer func() { @@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "Failed to create iptables client") - r, err := newRouter(iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router manager") require.NoError(t, r.init(nil)) diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 6ef159e01..c88774c1f 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -11,6 +12,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - ipt, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + ipt, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create iptables manager: %w", err) } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index aa90d3b9a..d864914fe 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -44,7 +44,7 @@ type Manager struct { } // Create nftables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, @@ -53,7 +53,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} var err error - m.router, err = newRouter(workTable, wgIface) + m.router, err = newRouter(workTable, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -93,6 +93,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, }); err != nil { log.Errorf("failed to update state: %v", err) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index c7f05dcb7..adec802c8 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) { func TestNftablesManagerRuleOrder(t *testing.T) { // This test verifies rule insertion order in nftables peer ACLs // We add accept rule first, then deny rule to test ordering behavior - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr := runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err, "failed to create manager") require.NoError(t, manager.Init(nil)) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 648a6aedf..0a2c79186 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -16,6 +16,7 @@ import ( "github.com/google/nftables/xt" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -32,12 +33,16 @@ const ( chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" chainNameForward = "FORWARD" + chainNameMangleForward = "netbird-mangle-forward" userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" dnatSuffix = "_dnat" snatSuffix = "_snat" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) const refreshRulesMapError = "refresh rules map: %w" @@ -63,9 +68,10 @@ type router struct { wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool + mtu uint16 } -func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { +func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ conn: &nftables.Conn{}, workTable: workTable, @@ -73,6 +79,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) rules: make(map[string]*nftables.Rule), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), + mtu: mtu, } r.ipsetCounter = refcounter.New( @@ -220,11 +227,23 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) + r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameMangleForward, + Table: r.workTable, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, + }) + // Add the single NAT rule that matches on mark if err := r.addPostroutingRules(); err != nil { return fmt.Errorf("add single nat rule: %v", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + if err := r.acceptForwardRules(); err != nil { log.Errorf("failed to add accept rules for the forward chain: %s", err) } @@ -745,6 +764,83 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + exprsOut := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 13, + Len: 1, + }, + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 1, + Mask: []byte{0x02}, + Xor: []byte{0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0x00}, + }, + &expr.Counter{}, + &expr.Exthdr{ + DestRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + &expr.Cmp{ + Op: expr.CmpOpGt, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Exthdr{ + SourceRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameMangleForward], + Exprs: exprsOut, + }) + + return nil +} + // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { sourceExp, err := r.applyNetwork(pair.Source, nil, true) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 4fdbf3505..3531b014b 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -17,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" ) const ( @@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { // need fw manager to init both acl mgr and router for all chains to be present - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) @@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index f805623d6..48b7b3741 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -3,6 +3,7 @@ package nftables import ( "fmt" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -10,6 +11,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - nft, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + nft, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create nftables manager: %w", err) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index ec2d2c57f..990630ee4 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "errors" "fmt" "net" @@ -27,7 +28,12 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -const layerTypeAll = 0 +const ( + layerTypeAll = 0 + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 +) const ( // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. @@ -36,6 +42,9 @@ const ( // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" + // EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic. + EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING" + // EnvForceUserspaceRouter forces userspace routing even if native routing is available. EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" @@ -122,6 +131,10 @@ type Manager struct { netstackServices map[serviceKey]struct{} netstackServiceMutex sync.RWMutex + + mtu uint16 + mssClampValue uint16 + mssClampEnabled bool } // decoder for packages @@ -140,16 +153,16 @@ type decoder struct { } // Create userspace firewall manager constructor -func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - return create(iface, nil, disableServerRoutes, flowLogger) +func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + return create(iface, nil, disableServerRoutes, flowLogger, mtu) } -func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { if nativeFirewall == nil { return nil, errors.New("native firewall is nil") } - mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) + mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } @@ -157,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall. return mgr, nil } -func parseCreateEnv() (bool, bool) { - var disableConntrack, enableLocalForwarding bool +func parseCreateEnv() (bool, bool, bool) { + var disableConntrack, enableLocalForwarding, disableMSSClamping bool var err error if val := os.Getenv(EnvDisableConntrack); val != "" { disableConntrack, err = strconv.ParseBool(val) @@ -177,12 +190,18 @@ func parseCreateEnv() (bool, bool) { log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err) } } + if val := os.Getenv(EnvDisableMSSClamping); val != "" { + disableMSSClamping, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err) + } + } - return disableConntrack, enableLocalForwarding + return disableConntrack, enableLocalForwarding, disableMSSClamping } -func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - disableConntrack, enableLocalForwarding := parseCreateEnv() +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv() m := &Manager{ decoders: sync.Pool{ @@ -213,13 +232,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe dnatMappings: make(map[netip.Addr]netip.Addr), portDNATRules: []portDNATRule{}, netstackServices: make(map[serviceKey]struct{}), + mtu: mtu, } m.routingEnabled.Store(false) + if !disableMSSClamping { + m.mssClampEnabled = true + m.mssClampValue = mtu - ipTCPHeaderMinSize + } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) } - if disableConntrack { log.Info("conntrack is disabled") } else { @@ -227,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) } - - // netstack needs the forwarder for local traffic if m.netstack && m.localForwarding { if err := m.initForwarder(); err != nil { log.Errorf("failed to initialize forwarder: %v", err) } } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } @@ -337,7 +357,7 @@ func (m *Manager) initForwarder() error { return errors.New("forwarding not supported") } - forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) + forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu) if err != nil { m.routingEnabled.Store(false) return fmt.Errorf("create forwarder: %w", err) @@ -643,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return false } - if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { - return true + switch d.decoded[1] { + case layers.LayerTypeUDP: + if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { + return true + } + case layers.LayerTypeTCP: + // Clamp MSS on all TCP SYN packets, including those from local IPs. + // SNATed routed traffic may appear as local IP but still requires clamping. + if m.mssClampEnabled { + m.clampTCPMSS(packetData, d) + } } m.trackOutbound(d, srcIP, dstIP, packetData, size) @@ -691,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } +// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation. +// Both sides advertise their MSS during connection establishment, so we need to clamp both. +func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { + if !d.tcp.SYN { + return false + } + if len(d.tcp.Options) == 0 { + return false + } + + mssOptionIndex := -1 + var currentMSS uint16 + for i, opt := range d.tcp.Options { + if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { + currentMSS = binary.BigEndian.Uint16(opt.OptionData) + if currentMSS > m.mssClampValue { + mssOptionIndex = i + break + } + } + } + + if mssOptionIndex == -1 { + return false + } + + ipHeaderSize := int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + + if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { + return false + } + + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + return true +} + +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { + tcpHeaderStart := ipHeaderSize + tcpOptionsStart := tcpHeaderStart + 20 + + optOffset := tcpOptionsStart + for j := 0; j < mssOptionIndex; j++ { + switch d.tcp.Options[j].OptionType { + case layers.TCPOptionKindEndList, layers.TCPOptionKindNop: + optOffset++ + default: + optOffset += 2 + len(d.tcp.Options[j].OptionData) + } + } + + mssValueOffset := optOffset + 2 + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + + m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) + return true +} + +func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) { + tcpLayer := packetData[tcpHeaderStart:] + tcpLength := len(packetData) - tcpHeaderStart + + tcpLayer[16] = 0 + tcpLayer[17] = 0 + + var pseudoSum uint32 + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += uint32(tcpLength) + + var sum uint32 = pseudoSum + for i := 0; i < tcpLength-1; i += 2 { + sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) + } + if tcpLength%2 == 1 { + sum += uint32(tcpLayer[tcpLength-1]) << 8 + } + + for sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + checksum := ^uint16(sum) + binary.BigEndian.PutUint16(tcpLayer[16:18], checksum) +} + func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 0cffcc1a7..5a2d0410f 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -17,6 +17,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { } } -// generateTCPPacketWithFlags creates a TCP packet with specific flags -func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { - b.Helper() - - ipv4 := &layers.IPv4{ - TTL: 64, - Version: 4, - SrcIP: srcIP, - DstIP: dstIP, - Protocol: layers.IPProtocolTCP, - } - - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - - // Set TCP flags - tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 - tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 - tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 - tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 - tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 - - require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) - return buf.Bytes() -} - func BenchmarkRouteACLs(b *testing.B) { manager := setupRoutedManager(b, "10.10.0.100/16") @@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) { } } } + +// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound. +// This shows the overhead difference between the common case (non-SYN packets, fast path) +// and the rare case (SYN packets that need clamping, expensive path). +func BenchmarkMSSClamping(b *testing.B) { + scenarios := []struct { + name string + description string + genPacket func(*testing.B, net.IP, net.IP) []byte + frequency string + }{ + { + name: "syn_needs_clamp", + description: "SYN packet needing MSS clamping", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + frequency: "~0.1% of traffic - EXPENSIVE", + }, + { + name: "syn_no_clamp_needed", + description: "SYN packet with already-small MSS", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200) + }, + frequency: "~0.05% of traffic", + }, + { + name: "tcp_ack", + description: "Non-SYN TCP packet (ACK, data transfer)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + frequency: "~60-70% of traffic - FAST PATH", + }, + { + name: "tcp_psh_ack", + description: "TCP data packet (PSH+ACK)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + frequency: "~10-20% of traffic - FAST PATH", + }, + { + name: "udp", + description: "UDP packet", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + frequency: "~20-30% of traffic - FAST PATH", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled +// for the common case (non-SYN TCP packets). +func BenchmarkMSSClampingOverhead(b *testing.B) { + scenarios := []struct { + name string + enabled bool + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "disabled_tcp_ack", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "enabled_tcp_ack", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "disabled_syn_needs_clamp", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "enabled_syn_needs_clamp", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = sc.enabled + if sc.enabled { + manager.mssClampValue = 1240 + } + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases +func BenchmarkMSSClampingMemory(b *testing.B) { + scenarios := []struct { + name string + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "tcp_ack_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "syn_needs_clamp", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "udp_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte { + b.Helper() + + ip := &layers.IPv4{ + Version: 4, + IHL: 5, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Seq: 1000, + Window: 65535, + } + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ip)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{}))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index 73f3face8..eb5aa3343 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -12,6 +12,7 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) require.NotNil(t, manager) @@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(tb, err) require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) @@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index bac06814d..c56a078fc 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "fmt" "net" "net/netip" @@ -17,6 +18,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nbiface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" @@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface, false, flowLogger) + manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() @@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) time.Sleep(time.Second) @@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() // Close the existing tracker @@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) { require.Equal(t, tc.expected, isAllowed, tc.desc) } } + +func TestMSSClamping(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, 1280) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") + expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) + require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + srcIP := net.ParseIP("100.10.0.2") + dstIP := net.ParseIP("8.8.8.8") + + t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + }) + + t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { + lowMSS := uint16(1200) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified") + }) + + t.Run("SYN-ACK packet gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + }) + + t.Run("Non-SYN packet unchanged", func(t *testing.T) { + packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck)) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Empty(t, d.tcp.Options, "ACK packet should have no options") + }) +} + +func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + ACK: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Window: 65535, + } + + if flags&uint16(conntrack.TCPSyn) != 0 { + tcpLayer.SYN = true + } + if flags&uint16(conntrack.TCPAck) != 0 { + tcpLayer.ACK = true + } + if flags&uint16(conntrack.TCPFin) != 0 { + tcpLayer.FIN = true + } + if flags&uint16(conntrack.TCPRst) != 0 { + tcpLayer.RST = true + } + if flags&uint16(conntrack.TCPPush) != 0 { + tcpLayer.PSH = true + } + + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 42a3e0800..00cb3f1df 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -45,7 +45,7 @@ type Forwarder struct { netstack bool } -func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { +func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow HandleLocal: false, }) - mtu, err := iface.GetDevice().MTU() - if err != nil { - return nil, fmt.Errorf("get MTU: %w", err) - } nicID := tcpip.NICID(1) endpoint := &endpoint{ logger: logger, @@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } if err := s.CreateNIC(nicID, endpoint); err != nil { - return nil, fmt.Errorf("failed to create NIC: %v", err) + return nil, fmt.Errorf("create NIC: %v", err) } protoAddr := tcpip.ProtocolAddress{ diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index d146de5e4..55743d975 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -49,7 +49,7 @@ type idleConn struct { conn *udpPacketConn } -func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { +func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ logger: logger, diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index d726474cf..d2599e577 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) { func BenchmarkDNATConcurrency(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) { b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) { func BenchmarkDNATMemoryAllocations(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 2a285484c..400d61020 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -16,7 +17,7 @@ import ( func TestDNATTranslationCorrectness(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { func TestDNATMappingManagement(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) { func TestInboundPortDNAT(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) { func TestInboundPortDNATNegative(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index ee1bb8a23..d9f9f1aa8 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -10,6 +10,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) if !statefulMode { diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index daf4979ce..638245bf7 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/netflow" @@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 11575d500..451b83f92 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface, false, flowLogger) + pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index ad69bcf43..0c7bd9f0a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -506,7 +506,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil From 641eb5140bf51a5b6f2ca1b3a391662bffb86e32 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 4 Nov 2025 21:56:53 +0100 Subject: [PATCH 038/118] [client] Allow INPUT traffic on the compat iptables filter table for nftables (#4742) --- client/firewall/manager/firewall.go | 3 + client/firewall/nftables/acl_linux.go | 41 +++---- client/firewall/nftables/manager_linux.go | 141 ++++------------------ client/firewall/nftables/router_linux.go | 94 +++++++++++---- client/internal/dnsfwd/manager.go | 42 ++++++- client/internal/engine.go | 35 +----- 6 files changed, 146 insertions(+), 210 deletions(-) diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 7ee33118b..72e6a5c68 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -100,6 +100,9 @@ type Manager interface { // // If comment argument is empty firewall manager should set // rule ID as comment for the rule + // + // Note: Callers should call Flush() after adding rules to ensure + // they are applied to the kernel and rule handles are refreshed. AddPeerFiltering( id []byte, ip net.IP, diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 9ff5b8c92..a9d066e2f 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -29,8 +29,6 @@ const ( chainNameForwardFilter = "netbird-acl-forward-filter" chainNameManglePrerouting = "netbird-mangle-prerouting" chainNameManglePostrouting = "netbird-mangle-postrouting" - - allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) const flushError = "flush: %w" @@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { // createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - // mask - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: []byte{0, 0, 0, 0}, - Xor: []byte{0, 0, 0, 0}, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, &expr.Verdict{ Kind: expr.VerdictAccept, }, @@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering( action firewall.Action, ipset *nftables.Set, ) (*Rule, error) { - ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) + ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ nftRule: r.nftRule, @@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering( } if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf(flushError, err) + return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err) } ruleStruct := &Rule{ - nftRule: nftRule, + nftRule: nftRule, + // best effort mangle rule mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset, ruleID: ruleId, @@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt }, ) - return m.rConn.AddRule(&nftables.Rule{ + nfRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: m.chainPrerouting, Exprs: preroutingExprs, UserData: userData, }) + + if err := m.rConn.Flush(); err != nil { + log.Errorf("failed to flush mangle rule %s: %v", string(userData), err) + return nil + } + + return nfRule } func (m *AclManager) createDefaultChains() (err error) { @@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro return nil } -func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { - rulesetID := ":" +func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { + rulesetID := ":" + string(proto) + ":" if sPort != nil { rulesetID += sPort.String() } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index d864914fe..bd19f1067 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -1,11 +1,11 @@ package nftables import ( - "bytes" "context" "fmt" "net" "net/netip" + "os" "sync" "github.com/google/nftables" @@ -19,13 +19,22 @@ import ( ) const ( - // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client tableNameNetbird = "netbird" + // envTableName is the environment variable to override the table name + envTableName = "NB_NFTABLES_TABLE" tableNameFilter = "filter" chainNameInput = "INPUT" ) +func getTableName() string { + if name := os.Getenv(envTableName); name != "" { + return name + } + return tableNameNetbird +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string @@ -50,7 +59,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} + workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ -198,44 +207,11 @@ func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() - err := m.aclManager.createDefaultAllowRules() - if err != nil { - return fmt.Errorf("failed to create default allow rules: %v", err) + if err := m.aclManager.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create default allow rules: %w", err) } - - chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list of chains: %w", err) - } - - var chain *nftables.Chain - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - chain = c - break - } - } - - if chain == nil { - log.Debugf("chain INPUT not found. Skipping add allow netbird rule") - return nil - } - - rules, err := m.rConn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) - } - - if rule := m.detectAllowNetbirdRule(rules); rule != nil { - log.Debugf("allow netbird rule already exists: %v", rule) - return nil - } - - m.applyAllowNetbirdRules(chain) - - err = m.rConn.Flush() - if err != nil { - return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush allow input netbird rules: %w", err) } return nil @@ -251,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.resetNetbirdInputRules(); err != nil { - return fmt.Errorf("reset netbird input rules: %v", err) - } - if err := m.router.Reset(); err != nil { return fmt.Errorf("reset router: %v", err) } @@ -274,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nil } -func (m *Manager) resetNetbirdInputRules() error { - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list chains: %w", err) - } - - m.deleteNetbirdInputRules(chains) - - return nil -} - -func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - rules, err := m.rConn.GetRules(c.Table, c) - if err != nil { - log.Errorf("get rules for chain %q: %v", c.Name, err) - continue - } - - m.deleteMatchingRules(rules) - } - } -} - -func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } -} - func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } @@ -399,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { return nil, fmt.Errorf("list of tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } -func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { - rule := &nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, - UserData: []byte(allowNetbirdInputRuleID), - } - _ = m.rConn.InsertRule(rule) -} - -func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { - ifName := ifname(m.wgIface.Name()) - for _, rule := range existedRules { - if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { - if len(rule.Exprs) < 4 { - if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { - continue - } - if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { - continue - } - return rule - } - } - } - return nil -} - func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { rule := &nftables.Rule{ Table: table, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 0a2c79186..6192c92aa 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -37,6 +37,7 @@ const ( userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" + userDataAcceptInputRule = "inputaccept" dnatSuffix = "_dnat" snatSuffix = "_snat" @@ -103,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou func (r *router) init(workTable *nftables.Table) error { r.workTable = workTable - if err := r.removeAcceptForwardRules(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + if err := r.removeAcceptFilterRules(); err != nil { + log.Errorf("failed to clean up rules from filter table: %s", err) } if err := r.createContainers(); err != nil { @@ -118,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error { return nil } -// Reset cleans existing nftables default forward rules from the system +// Reset cleans existing nftables filter table rules from the system func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() var merr *multierror.Error - if err := r.removeAcceptForwardRules(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) + if err := r.removeAcceptFilterRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } if err := r.removeNatPreroutingRules(); err != nil { @@ -936,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +// This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") @@ -945,7 +947,7 @@ func (r *router) acceptForwardRules() error { fw := "iptables" defer func() { - log.Debugf("Used %s to add accept forward rules", fw) + log.Debugf("Used %s to add accept forward and input rules", fw) }() // Try iptables first and fallback to nftables if iptables is not available @@ -955,22 +957,30 @@ func (r *router) acceptForwardRules() error { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptForwardRulesNftables() + return r.acceptFilterRulesNftables() } - return r.acceptForwardRulesIptables(ipt) + return r.acceptFilterRulesIptables(ipt) } -func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) } else { - log.Debugf("added iptables rule: %v", rule) + log.Debugf("added iptables forward rule: %v", rule) } } + inputRule := r.getAcceptInputRule() + if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + } else { + log.Debugf("added iptables input rule: %v", inputRule) + } + return nberrors.FormatErrorOrNil(merr) } @@ -982,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string { } } -func (r *router) acceptForwardRulesNftables() error { +func (r *router) getAcceptInputRule() []string { + return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} +} + +func (r *router) acceptFilterRulesNftables() error { intf := ifname(r.wgIface.Name()) - // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ @@ -1018,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error { }, } - // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -1031,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error { Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } - r.conn.InsertRule(oifRule) + inputRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: chainNameInput, + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptInputRule), + } + r.conn.InsertRule(inputRule) + return nil } -func (r *router) removeAcceptForwardRules() error { +func (r *router) removeAcceptFilterRules() error { if r.filterTable == nil { return nil } - // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptForwardRulesNftables() + return r.removeAcceptFilterRulesNftables() } - return r.removeAcceptForwardRulesIptables(ipt) + return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptForwardRulesNftables() error { +func (r *router) removeAcceptFilterRulesNftables() error { chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + if chain.Table.Name != r.filterTable.Name { + continue + } + + if chain.Name != chainNameForward && chain.Name != chainNameInput { continue } @@ -1070,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error { for _, rule := range rules { if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("delete rule: %v", err) } @@ -1085,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error { return nil } -func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) } } + inputRule := r.getAcceptInputRule() + if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index b26836d17..58b88d9ef 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -15,6 +15,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -134,6 +135,8 @@ func (m *Manager) Stop(ctx context.Context) error { } } + m.unregisterNetstackServices() + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -158,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error { dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add udp firewall rule: %w", err) } - m.fwRules = dnsRules tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add tcp firewall rule: %w", err) } + + if err := m.firewall.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + m.fwRules = dnsRules m.tcpRules = tcpRules + m.registerNetstackServices() + return nil } +func (m *Manager) registerNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, m.serverPort) + registrar.RegisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + +func (m *Manager) unregisterNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort) + registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error for _, rule := range m.fwRules { diff --git a/client/internal/engine.go b/client/internal/engine.go index 0c7bd9f0a..3c7d52cb3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -298,17 +298,12 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } + e.stopDNSForwarder() + if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } - if e.dnsForwardMgr != nil { - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil - } - if e.srWatcher != nil { e.srWatcher.Close() } @@ -1873,7 +1868,6 @@ func (e *Engine) updateDNSForwarder( func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) - e.registerDNSServices() if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) @@ -1893,34 +1887,9 @@ func (e *Engine) stopDNSForwarder() { log.Errorf("failed to stop DNS forward: %v", err) } - e.unregisterDNSServices() e.dnsForwardMgr = nil } -func (e *Engine) registerDNSServices() { - if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { - if registrar, ok := e.firewall.(interface { - RegisterNetstackService(protocol nftypes.Protocol, port uint16) - }); ok { - registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) - registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) - log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) - } - } -} - -func (e *Engine) unregisterDNSServices() { - if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { - if registrar, ok := e.firewall.(interface { - UnregisterNetstackService(protocol nftypes.Protocol, port uint16) - }); ok { - registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) - registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) - log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) - } - } -} - func (e *Engine) GetNet() (*netstack.Net, error) { e.syncMsgMux.Lock() intf := e.wgInterface From c92e6c1b5fd040248b4fd36a2e397daeeb21ba9f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:15:37 +0100 Subject: [PATCH 039/118] [client] Block on all subsystems on shutdown (#4709) --- client/internal/connect.go | 32 +++-- client/internal/dns/server.go | 9 +- client/internal/engine.go | 119 +++++++++++++----- client/internal/netflow/manager.go | 17 ++- client/internal/peer/guard/sr_watcher.go | 9 +- client/internal/routemanager/manager.go | 14 ++- .../internal/routeselector/routeselector.go | 4 - 7 files changed, 139 insertions(+), 65 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index c9331baf5..bb7c2b38b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -34,7 +35,6 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) @@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } <-engineCtx.Done() + c.engineMutex.Lock() - if c.engine != nil && c.engine.wgInterface != nil { - log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) - if err := c.engine.Stop(); err != nil { + engine := c.engine + c.engine = nil + c.engineMutex.Unlock() + + if engine != nil && engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } - c.engine = nil } - c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() @@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType { } func (c *ConnectClient) Stop() error { - if c == nil { - return nil + engine := c.Engine() + if engine != nil { + if err := engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } } - c.engineMutex.Lock() - defer c.engineMutex.Unlock() - - if c.engine == nil { - return nil - } - if err := c.engine.Stop(); err != nil { - return fmt.Errorf("stop engine: %w", err) - } - return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8cb886203..afaf0579f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface { // DefaultServer dns server object type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + shutdownWg sync.WaitGroup // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. // This is different from ServiceEnable=false from management which completely disables the DNS service. disableSys bool @@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { s.ctxCancel() + s.shutdownWg.Wait() s.mux.Lock() defer s.mux.Unlock() @@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.applyHostConfig() + s.shutdownWg.Add(1) go func() { - // persist dns state right away + defer s.shutdownWg.Done() if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 3c7d52cb3..ebc05c453 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -148,6 +148,8 @@ type Engine struct { // syncMsgMux is used to guarantee sequential Management Service message processing syncMsgMux *sync.Mutex + // sshMux protects sshServer field access + sshMux sync.Mutex config *EngineConfig mobileDep MobileDependency @@ -200,8 +202,10 @@ type Engine struct { flowManager nftypes.FlowManager // WireGuard interface monitor - wgIfaceMonitor *WGIfaceMonitor - wgIfaceMonitorWg sync.WaitGroup + wgIfaceMonitor *WGIfaceMonitor + + // shutdownWg tracks all long-running goroutines to ensure clean shutdown + shutdownWg sync.WaitGroup probeStunTurn *relay.StunTurnProbe } @@ -320,10 +324,6 @@ func (e *Engine) Stop() error { e.cancel() } - // very ugly but we want to remove peers from the WireGuard interface first before removing interface. - // Removing peers happens in the conn.Close() asynchronously - time.Sleep(500 * time.Millisecond) - e.close() // stop flow manager after wg interface is gone @@ -331,8 +331,6 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - log.Infof("stopped Netbird Engine") - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -343,12 +341,52 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } - // Stop WireGuard interface monitor and wait for it to exit - e.wgIfaceMonitorWg.Wait() + timeout := e.calculateShutdownTimeout() + log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) + shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil { + log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout) + } + + log.Infof("stopped Netbird Engine") return nil } +// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. +func (e *Engine) calculateShutdownTimeout() time.Duration { + peerCount := len(e.peerStore.PeersPubKey()) + + baseTimeout := 10 * time.Second + perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond + timeout := baseTimeout + perPeerTimeout + + maxTimeout := 30 * time.Second + if timeout > maxTimeout { + timeout = maxTimeout + } + + return timeout +} + +// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout. +func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. // However, they will be established once an event with a list of peers to connect to will be received from Management Service @@ -478,14 +516,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // monitor WireGuard interface lifecycle and restart engine on changes e.wgIfaceMonitor = NewWGIfaceMonitor() - e.wgIfaceMonitorWg.Add(1) + e.shutdownWg.Add(1) go func() { - defer e.wgIfaceMonitorWg.Done() + defer e.shutdownWg.Done() if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { log.Infof("WireGuard interface monitor: %s, restarting engine", err) - e.restartEngine() + e.triggerClientRestart() } else if err != nil { log.Warnf("WireGuard interface monitor: %s", err) } @@ -669,9 +707,11 @@ func (e *Engine) removeAllPeers() error { func (e *Engine) removePeer(peerKey string) error { log.Debugf("removing peer from engine %s", peerKey) + e.sshMux.Lock() if !isNil(e.sshServer) { e.sshServer.RemoveAuthorizedKey(peerKey) } + e.sshMux.Unlock() e.connMgr.RemovePeerConn(peerKey) @@ -873,6 +913,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { log.Warnf("running SSH server on %s is not supported", runtime.GOOS) return nil } + e.sshMux.Lock() // start SSH server if it wasn't running if isNil(e.sshServer) { listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) @@ -880,34 +921,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort) } // nil sshServer means it has not yet been started - var err error - e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr) - + server, err := e.sshServerFunc(e.config.SSHKey, listenAddr) if err != nil { + e.sshMux.Unlock() return fmt.Errorf("create ssh server: %w", err) } + + e.sshServer = server + e.sshMux.Unlock() + go func() { // blocking - err = e.sshServer.Start() + err = server.Start() if err != nil { // will throw error when we stop it even if it is a graceful stop log.Debugf("stopped SSH server with error %v", err) } - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() + e.sshMux.Lock() e.sshServer = nil + e.sshMux.Unlock() log.Infof("stopped SSH server") }() } else { + e.sshMux.Unlock() log.Debugf("SSH server is already running") } - } else if !isNil(e.sshServer) { - // Disable SSH server request, so stop it if it was running - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed to stop SSH server %v", err) + } else { + e.sshMux.Lock() + if !isNil(e.sshServer) { + // Disable SSH server request, so stop it if it was running + err := e.sshServer.Stop() + if err != nil { + log.Warnf("failed to stop SSH server %v", err) + } + e.sshServer = nil } - e.sshServer = nil + e.sshMux.Unlock() } return nil } @@ -944,7 +993,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. func (e *Engine) receiveManagementEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() info, err := system.GetInfoWithChecks(e.ctx, e.checks) if err != nil { log.Warnf("failed to get system info with checks: %v", err) @@ -1120,6 +1171,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.statusRecorder.FinishPeerListModifications() // update SSHServer by adding remote peer SSH keys + e.sshMux.Lock() if !isNil(e.sshServer) { for _, config := range networkMap.GetRemotePeers() { if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { @@ -1130,6 +1182,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } } + e.sshMux.Unlock() } // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store @@ -1372,7 +1425,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers func (e *Engine) receiveSignalEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() // connect to a stream of messages coming from the signal server err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { e.syncMsgMux.Lock() @@ -1489,12 +1544,14 @@ func (e *Engine) close() { e.statusRecorder.SetWgIface(nil) } + e.sshMux.Lock() if !isNil(e.sshServer) { err := e.sshServer.Stop() if err != nil { log.Warnf("failed stopping the SSH server: %v", err) } } + e.sshMux.Unlock() if e.firewall != nil { err := e.firewall.Close(e.stateManager) @@ -1725,8 +1782,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool { return allHealthy } -// restartEngine restarts the engine by cancelling the client context -func (e *Engine) restartEngine() { +// triggerClientRestart triggers a full client restart by cancelling the client context. +// Note: This does NOT just restart the engine - it cancels the entire client context, +// which causes the connect client's retry loop to create a completely new engine. +func (e *Engine) triggerClientRestart() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -1748,7 +1807,9 @@ func (e *Engine) startNetworkMonitor() { } e.networkMonitor = networkmonitor.New() + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() if err := e.networkMonitor.Listen(e.ctx); err != nil { if errors.Is(err, context.Canceled) { log.Infof("network monitor stopped") @@ -1758,8 +1819,8 @@ func (e *Engine) startNetworkMonitor() { return } - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() + log.Infof("Network monitor: detected network change, triggering client restart") + e.triggerClientRestart() }() } diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index e3b188468..7752c97b0 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -24,6 +24,7 @@ import ( // Manager handles netflow tracking and logging type Manager struct { mux sync.Mutex + shutdownWg sync.WaitGroup logger nftypes.FlowLogger flowConfig *nftypes.FlowConfig conntrack nftypes.ConnTracker @@ -105,8 +106,15 @@ func (m *Manager) resetClient() error { ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel - go m.receiveACKs(ctx, flowClient) - go m.startSender(ctx) + m.shutdownWg.Add(2) + go func() { + defer m.shutdownWg.Done() + m.receiveACKs(ctx, flowClient) + }() + go func() { + defer m.shutdownWg.Done() + m.startSender(ctx) + }() return nil } @@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error { // Close cleans up all resources func (m *Manager) Close() { m.mux.Lock() - defer m.mux.Unlock() - if err := m.disableFlow(); err != nil { log.Warnf("failed to disable flow manager: %v", err) } + m.mux.Unlock() + + m.shutdownWg.Wait() } // GetLogger returns the flow logger diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 686430752..6f4f5ad4f 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -19,11 +19,10 @@ type SRWatcher struct { signalClient chNotifier relayManager chNotifier - listeners map[chan struct{}]struct{} - mu sync.Mutex - iFaceDiscover stdnet.ExternalIFaceDiscover - iceConfig ice.Config - + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config cancelIceMonitor context.CancelFunc } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 37974cd17..26cf758d9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -81,6 +81,7 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex + shutdownWg sync.WaitGroup clientNetworks map[route.HAUniqueID]*client.Watcher routeSelector *routeselector.RouteSelector serverRouter *server.Router @@ -283,6 +284,7 @@ func (m *DefaultManager) SetDNSForwarderPort(port uint16) { // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() + m.shutdownWg.Wait() if m.serverRouter != nil { m.serverRouter.CleanUp() } @@ -485,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { } clientNetworkWatcher := client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) } @@ -527,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout } clientNetworkWatcher = client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() } update := client.RoutesUpdate{ UpdateSerial: updateSerial, diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index e4a78599e..61c8bbc79 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -9,8 +9,6 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" ) @@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { defer rs.mu.RUnlock() if rs.deselectAll { - log.Debugf("Route %s not selected (deselect all)", routeID) return false } _, deselected := rs.deselectedRoutes[routeID] isSelected := !deselected - log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected) return isSelected } From 75327d951941f60229565a38ece91d4ceae97303 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:00:20 +0100 Subject: [PATCH 040/118] [client] Add login_hint to oidc flows (#4724) --- client/android/login.go | 2 +- client/cmd/login.go | 26 +++++++++++++++++++++----- client/cmd/up.go | 9 ++++++++- client/internal/auth/device_flow.go | 25 +++++++++++++++++++++++++ client/internal/auth/oauth.go | 18 +++++++++++------- client/internal/auth/pkce_flow.go | 3 +++ client/internal/device_auth.go | 2 ++ client/internal/pkce_auth.go | 2 ++ client/ios/NetBirdSDK/client.go | 2 +- client/proto/daemon.pb.go | 21 ++++++++++++++++----- client/proto/daemon.proto | 3 +++ client/server/server.go | 6 +++++- client/ui/client_ui.go | 13 +++++++++++-- 13 files changed, 109 insertions(+), 23 deletions(-) diff --git a/client/android/login.go b/client/android/login.go index 0df78dbc3..16df24ba8 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error { } func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "") if err != nil { return nil, err } diff --git a/client/cmd/login.go b/client/cmd/login.go index 40b55f858..b0c877faa 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str Username: &username, } + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { loginRequest.OptionalPreSharedKey = &preSharedKey } @@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, return fmt.Errorf("read config file %s: %v", configFilePath, err) } - err = foregroundLogin(ctx, cmd, config, setupKey) + err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo return nil } -func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { +func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { needsLogin := false err := WithBackOff(func() error { @@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken := "" if setupKey == "" && needsLogin { - tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config) + tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman return nil } -func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) +func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) { + hint := "" + pm := profilemanager.NewProfileManager() + profileState, err := pm.GetProfileState(profileName) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + hint = profileState.Email + } + + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint) if err != nil { return nil, err } diff --git a/client/cmd/up.go b/client/cmd/up.go index d047c041e..80175f7be 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) - err = foregroundLogin(ctx, cmd, config, providedSetupKey) + err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ loginRequest.ProfileName = &activeProf.Name loginRequest.Username = &username + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index da4f16c8d..8ca760742 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow deviceCode.VerificationURIComplete = deviceCode.VerificationURI } + if d.providerConfig.LoginHint != "" { + deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint) + if deviceCode.VerificationURI != "" { + deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint) + } + } + return deviceCode, err } +func appendLoginHint(uri, loginHint string) string { + if uri == "" || loginHint == "" { + return uri + } + + parsedURL, err := url.Parse(uri) + if err != nil { + log.Debugf("failed to parse verification URI for login_hint: %v", err) + return uri + } + + query := parsedURL.Query() + query.Set("login_hint", loginHint) + parsedURL.RawQuery = query.Encode() + + return parsedURL.String() +} + func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { form := url.Values{} form.Add("client_id", d.providerConfig.ClientID) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 4458f600c..9fbd6cf5f 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string { // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } - pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint) if err != nil { - // fallback to device code flow log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debug("falling back to device code flow") - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } return pkceFlow, nil } // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow -func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } + + pkceFlowInfo.ProviderConfig.LoginHint = hint + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow -func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { switch s, ok := gstatus.FromError(err); { @@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. } } + deviceFlowInfo.ProviderConfig.LoginHint = hint + return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 8741e8636..738d3e34f 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn params = append(params, oauth2.SetAuthURLParam("max_age", "0")) } } + if p.providerConfig.LoginHint != "" { + params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint)) + } authURL := p.oAuthConfig.AuthCodeURL(state, params...) diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index 6bd29801d..7f7d06130 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct { Scope string // UseIDToken indicates if the id token should be used for authentication UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go index a713bb342..23c92e8af 100644 --- a/client/internal/pkce_auth.go +++ b/client/internal/pkce_auth.go @@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct { DisablePromptLogin bool // LoginFlag is used to configure the PKCE flow login behavior LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 2109d4b15..fa1c89aab 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string { ConfigPath: c.cfgFile, }) - oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "") if err != nil { return err.Error() } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 841e3c0f7..02f09b08a 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -279,8 +279,10 @@ type LoginRequest struct { ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // hint is used to pre-fill the email/username field during SSO authentication + Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { @@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 { return 0 } +func (x *LoginRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xc3\x0e\n" + + "\fEmptyRequest\"\xe5\x0e\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" + + "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" + "\x0e_block_inboundB\x0e\n" + "\f_profileNameB\v\n" + "\t_usernameB\x06\n" + - "\x04_mtu\"\xb5\x01\n" + + "\x04_mtuB\a\n" + + "\x05_hint\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 5b27b4d98..8d1080051 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -158,6 +158,9 @@ message LoginRequest { optional string username = 31; optional int64 mtu = 32; + + // hint is used to pre-fill the email/username field during SSO authentication + optional string hint = 33; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index 3641e6f92..6699cdadc 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro state.Set(internal.StatusConnecting) if msg.SetupKey == "" { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient) + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index f4350b251..e580be56d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -610,11 +610,20 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(ctx, &proto.LoginRequest{ + loginReq := &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, - }) + } + + profileState, err := s.profileManager.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginReq.Hint = &profileState.Email + } + + loginResp, err := conn.Login(ctx, loginReq) if err != nil { return nil, fmt.Errorf("login to management: %w", err) } From 229e0038ee643ec775d0670c030db70670ae298c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:30:17 +0100 Subject: [PATCH 041/118] [client] Add dns config to debug bundle (#4704) --- client/internal/debug/debug.go | 20 +++++++++ client/internal/debug/debug_darwin.go | 53 ++++++++++++++++++++++++ client/internal/debug/debug_mobile.go | 4 ++ client/internal/debug/debug_nondarwin.go | 16 +++++++ client/internal/debug/debug_nonunix.go | 7 ++++ client/internal/debug/debug_unix.go | 29 +++++++++++++ 6 files changed, 129 insertions(+) create mode 100644 client/internal/debug/debug_darwin.go create mode 100644 client/internal/debug/debug_nondarwin.go create mode 100644 client/internal/debug/debug_nonunix.go create mode 100644 client/internal/debug/debug_unix.go diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 442f54e71..fbec29ce3 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. +resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided. +scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. @@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information: The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing. For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged. + +DNS Configuration +The debug bundle includes platform-specific DNS configuration files: + +resolv.conf (Unix systems): +- Contains DNS resolver configuration from /etc/resolv.conf +- Includes nameserver entries, search domains, and resolver options +- All IP addresses and domain names are anonymized following the same rules as other files + +scutil_dns.txt (macOS only): +- Contains detailed DNS configuration from scutil --dns +- Shows DNS configuration for all network interfaces +- Includes search domains, nameservers, and DNS resolver settings +- All IP addresses and domain names are anonymized ` const ( @@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() { if err := g.addFirewallRules(); err != nil { log.Errorf("failed to add firewall rules to debug bundle: %v", err) } + + if err := g.addDNSInfo(); err != nil { + log.Errorf("failed to add DNS info to debug bundle: %v", err) + } } func (g *BundleGenerator) addReadme() error { diff --git a/client/internal/debug/debug_darwin.go b/client/internal/debug/debug_darwin.go new file mode 100644 index 000000000..91e10214f --- /dev/null +++ b/client/internal/debug/debug_darwin.go @@ -0,0 +1,53 @@ +//go:build darwin && !ios + +package debug + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + if err := g.addScutilDNS(); err != nil { + log.Errorf("failed to add scutil DNS output: %v", err) + } + + return nil +} + +func (g *BundleGenerator) addScutilDNS() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "scutil", "--dns") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("execute scutil --dns: %w", err) + } + + if len(bytes.TrimSpace(output)) == 0 { + return fmt.Errorf("no scutil DNS output") + } + + content := string(output) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil { + return fmt.Errorf("add scutil DNS output to zip: %w", err) + } + + return nil +} diff --git a/client/internal/debug/debug_mobile.go b/client/internal/debug/debug_mobile.go index c00c65132..3c1745ff3 100644 --- a/client/internal/debug/debug_mobile.go +++ b/client/internal/debug/debug_mobile.go @@ -5,3 +5,7 @@ package debug func (g *BundleGenerator) addRoutes() error { return nil } + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_nondarwin.go b/client/internal/debug/debug_nondarwin.go new file mode 100644 index 000000000..dfc2eace5 --- /dev/null +++ b/client/internal/debug/debug_nondarwin.go @@ -0,0 +1,16 @@ +//go:build unix && !darwin && !android + +package debug + +import ( + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + return nil +} diff --git a/client/internal/debug/debug_nonunix.go b/client/internal/debug/debug_nonunix.go new file mode 100644 index 000000000..18d017050 --- /dev/null +++ b/client/internal/debug/debug_nonunix.go @@ -0,0 +1,7 @@ +//go:build !unix + +package debug + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_unix.go b/client/internal/debug/debug_unix.go new file mode 100644 index 000000000..7e8a74eb0 --- /dev/null +++ b/client/internal/debug/debug_unix.go @@ -0,0 +1,29 @@ +//go:build unix && !android + +package debug + +import ( + "fmt" + "os" + "strings" +) + +const resolvConfPath = "/etc/resolv.conf" + +func (g *BundleGenerator) addResolvConf() error { + data, err := os.ReadFile(resolvConfPath) + if err != nil { + return fmt.Errorf("read %s: %w", resolvConfPath, err) + } + + content := string(data) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil { + return fmt.Errorf("add resolv.conf to zip: %w", err) + } + + return nil +} From 5c29d395b29d3d152117c11e3a375111863a204f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 6 Nov 2025 12:51:14 +0100 Subject: [PATCH 042/118] [management] activity events on group updates (#4750) --- management/server/group.go | 20 +++++++++++++++----- management/server/user.go | 38 ++++++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 487cb6d97..a29c28892 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -138,6 +138,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID) if err != nil { return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) @@ -157,11 +162,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } } - newGroup.AccountID = accountID - - events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) - eventsToStore = append(eventsToStore, events...) - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) if err != nil { return err @@ -335,6 +335,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac if err == nil && oldGroup != nil { addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) + + if oldGroup.Name != newGroup.Name { + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "old_name": oldGroup.Name, + "new_name": newGroup.Name, + } + am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta) + }) + } } else { addedPeers = append(addedPeers, newGroup.Peers...) eventsToStore = append(eventsToStore, func() { diff --git a/management/server/user.go b/management/server/user.go index d40d33c6a..25c87df9c 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -595,7 +595,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -621,6 +621,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac }) } + addedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, addedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get added groups for user %s update event: %v", oldUser.Id, err) + } + + for _, group := range addedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta) + }) + } + + removedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, removedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get removed groups for user %s update event: %v", oldUser.Id, err) + } + for _, group := range removedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) + }) + } + return eventsToStore } @@ -667,9 +696,10 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact peersToExpire = userPeers } + var removedGroups, addedGroups []string if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) - addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups) + removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups) + addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups) for _, peer := range userPeers { for _, groupID := range removedGroups { if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil { @@ -685,7 +715,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } From 2e16c9914a7ba92fa92f56c66c079e4dd1e0fd34 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Nov 2025 19:01:44 +0300 Subject: [PATCH 043/118] [management] Bump github.com/containerd/containerd from 1.7.27 to 1.7.29 (#4756) Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.27 to 1.7.29. - [Release notes](https://github.com/containerd/containerd/releases) - [Changelog](https://github.com/containerd/containerd/blob/main/RELEASES.md) - [Commits](https://github.com/containerd/containerd/compare/v1.7.27...v1.7.29) --- updated-dependencies: - dependency-name: github.com/containerd/containerd dependency-version: 1.7.29 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 98b158411..d93d2c064 100644 --- a/go.mod +++ b/go.mod @@ -102,9 +102,9 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/mod v0.25.0 + golang.org/x/mod v0.26.0 golang.org/x/net v0.42.0 - golang.org/x/oauth2 v0.28.0 + golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.33.0 google.golang.org/api v0.177.0 @@ -146,7 +146,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/containerd/containerd v1.7.27 // indirect + github.com/containerd/containerd v1.7.29 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect @@ -245,7 +245,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect - golang.org/x/time v0.5.0 // indirect + golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect diff --git a/go.sum b/go.sum index 9705f3aed..61ad8740e 100644 --- a/go.sum +++ b/go.sum @@ -142,8 +142,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= -github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= -github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= +github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE= +github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= @@ -818,8 +818,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -880,8 +880,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= -golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -993,8 +993,8 @@ golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= From 6aa4ba7af441c9c07fd76d7946e179ea5257e975 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 7 Nov 2025 10:44:46 +0100 Subject: [PATCH 044/118] [management] incremental network map builder (#4753) --- go.mod | 2 +- management/main.go | 10 +- management/server/account.go | 36 + management/server/account/manager.go | 1 + management/server/account_test.go | 55 + management/server/dns.go | 3 + management/server/group.go | 24 + management/server/grpcserver.go | 85 +- management/server/holder.go | 39 + management/server/mock_server/account_mock.go | 14 +- management/server/nameserver.go | 9 + management/server/networkmap.go | 80 + management/server/networks/manager.go | 3 + .../server/networks/resources/manager.go | 9 + management/server/networks/routers/manager.go | 9 + management/server/peer.go | 133 +- management/server/peer_test.go | 23 +- management/server/policy.go | 6 + management/server/posture_checks.go | 3 + management/server/route.go | 9 + management/server/settings/manager.go | 8 + management/server/store/sql_store.go | 1265 ++++++++++- .../store/sql_store_get_account_test.go | 1089 ++++++++++ .../server/store/sqlstore_bench_test.go | 951 ++++++++ management/server/store/store.go | 104 +- management/server/types/account.go | 13 + management/server/types/holder.go | 43 + management/server/types/networkmap.go | 58 + .../server/types/networkmap_golden_test.go | 1069 +++++++++ management/server/types/networkmapbuilder.go | 1932 +++++++++++++++++ management/server/updatechannel.go | 6 +- management/server/user.go | 4 + route/route.go | 1 + 33 files changed, 7018 insertions(+), 78 deletions(-) create mode 100644 management/server/holder.go create mode 100644 management/server/networkmap.go create mode 100644 management/server/store/sql_store_get_account_test.go create mode 100644 management/server/store/sqlstore_bench_test.go create mode 100644 management/server/types/holder.go create mode 100644 management/server/types/networkmap.go create mode 100644 management/server/types/networkmap_golden_test.go create mode 100644 management/server/types/networkmapbuilder.go diff --git a/go.mod b/go.mod index d93d2c064..68a12908d 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/mdlayher/socket v0.5.1 @@ -183,7 +184,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/management/main.go b/management/main.go index 561ed8f26..ff8482f97 100644 --- a/management/main.go +++ b/management/main.go @@ -1,11 +1,19 @@ package main import ( - "github.com/netbirdio/netbird/management/cmd" + "log" + "net/http" + // nolint:gosec + _ "net/http/pprof" "os" + + "github.com/netbirdio/netbird/management/cmd" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() if err := cmd.Execute(); err != nil { os.Exit(1) } diff --git a/management/server/account.go b/management/server/account.go index dca105ddf..0aecbd586 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -53,6 +53,9 @@ const ( peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + + envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" ) type userLoggedInOnce bool @@ -109,6 +112,11 @@ type DefaultAccountManager struct { loginFilter *loginFilter disableDefaultPolicy bool + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} } func isUniqueConstraintError(err error) bool { @@ -196,6 +204,18 @@ func BuildManager( log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) }() + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + am := &DefaultAccountManager{ Store: store, geo: geo, @@ -217,6 +237,10 @@ func BuildManager( permissionsManager: permissionsManager, loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, + holder: types.NewHolder(), + + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, } am.startWarmup(ctx) @@ -395,6 +419,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } go am.UpdateAccountPeers(ctx, accountID) } @@ -1477,6 +1504,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } if removedGroupAffectsPeers || newGroupsAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil { + return err + } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } @@ -2129,6 +2160,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us } if updateNetworkMap { + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + am.updatePeerInNetworkMapCache(peer.AccountID, peer) am.BufferUpdateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index fe9fb25c6..db377865a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -128,4 +128,5 @@ type Manager interface { GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool + RecalculateNetworkMapCache(ctx context.Context, accountId string) error } diff --git a/management/server/account_test.go b/management/server/account_test.go index 07d2f2383..200ba6b98 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } +func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SaveGroup(t) +} + func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + testAccountManager_NetworkUpdates_SaveGroup(t) +} + +func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + +func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { manager, account, peer1, _, _ := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SavePolicy(t) +} + func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_SavePolicy(t) +} + +func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ @@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePeer(t) +} + func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePeer(t) +} + +func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + +func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1377,6 +1422,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } + for drained := false; !drained; { + select { + case <-updMsg: + default: + drained = true + } + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1736,7 +1789,9 @@ func TestAccount_Copy(t *testing.T) { Address: "172.12.6.1/24", }, }, + NetworkMapCache: &types.NetworkMapBuilder{}, } + account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) diff --git a/management/server/dns.go b/management/server/dns.go index e5166ce47..decc5175d 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index a29c28892..3cf9290a2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -481,6 +493,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -519,6 +534,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -547,6 +565,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -585,6 +606,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 12b59b691..0a5236cb3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -7,8 +7,10 @@ import ( "net" "net/netip" "os" + "strconv" "strings" "sync" + "sync/atomic" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -44,6 +46,9 @@ import ( const ( envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" envBlockPeers = "NB_BLOCK_SAME_PEERS" + envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS" + + defaultSyncLim = 1000 ) // GRPCServer an instance of a Management gRPC API server @@ -63,6 +68,9 @@ type GRPCServer struct { logBlockedPeers bool blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + + syncSem atomic.Int32 + syncLim int32 } // NewServer creates a new Management server @@ -96,6 +104,17 @@ func NewServer( logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" + syncLim := int32(defaultSyncLim) + if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { + syncLimParsed, err := strconv.Atoi(syncLimStr) + if err != nil { + log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim) + } else { + //nolint:gosec + syncLim = int32(syncLimParsed) + } + } + return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -110,6 +129,8 @@ func NewServer( logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + + syncLim: syncLim, }, nil } @@ -151,6 +172,11 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + if s.syncSem.Load() >= s.syncLim { + return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") + } + s.syncSem.Add(1) + reqStart := time.Now() ctx := srv.Context() @@ -158,6 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi syncReq := &proto.SyncRequest{} peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { + s.syncSem.Add(-1) return err } realIP := getRealIP(ctx) @@ -172,6 +199,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) } if s.blockPeersWithSameConfig { + s.syncSem.Add(-1) return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) } } @@ -183,27 +211,34 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) - unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) - defer func() { - if unlock != nil { - unlock() - } - }() - accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN") log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String()) if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + s.syncSem.Add(-1) return status.Errorf(codes.PermissionDenied, "peer is not registered") } + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + start := time.Now() + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) + log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart)) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { @@ -213,21 +248,32 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return mapError(ctx, err) } + log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart)) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) + log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart)) + s.ephemeralManager.OnPeerConnected(ctx, peer) + log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart)) + s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) + log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart)) + if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } @@ -237,6 +283,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) + s.syncSem.Add(-1) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -509,10 +557,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) } + took := time.Since(reqStart) + if took > 7*time.Second { + log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart)) + } }() if loginReq.GetMeta() == nil { @@ -546,9 +600,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, mapError(ctx, err) } + log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) + // if the login request contains setup key then it is a registration request if loginReq.GetSetupKey() != "" { s.ephemeralManager.OnPeerDisconnected(ctx, peer) + log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart)) } loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) @@ -557,6 +614,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, status.Errorf(codes.Internal, "failed logging in peer") } + log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) @@ -716,6 +775,11 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("toSyncResponse: took %s", time.Since(start)) + }() + response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ @@ -780,6 +844,11 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("sendInitialSync: took %s", time.Since(start)) + }() + var err error var turnToken *Token @@ -822,10 +891,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } + sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ WgPubKey: s.wgKey.PublicKey().String(), Body: encryptedResp, }) + log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) if err != nil { log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) diff --git a/management/server/holder.go b/management/server/holder.go new file mode 100644 index 000000000..e8a26e1d0 --- /dev/null +++ b/management/server/holder.go @@ -0,0 +1,39 @@ +package server + +import ( + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) { + a := am.holder.GetAccount(account.Id) + if a == nil { + am.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + am.holder.AddAccount(account) +} + +func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { + return am.holder.GetAccount(accountID) +} + +func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account { + a := am.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { + am.holder.AddAccount(account) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e87043f26..8baffa58b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -125,9 +125,10 @@ type MockAccountManager struct { UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - AllowSyncFunc func(string, uint64) bool - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + AllowSyncFunc func(string, uint64) bool + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { } return true } + +func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error { + if am.RecalculateNetworkMapCacheFunc != nil { + return am.RecalculateNetworkMapCacheFunc(ctx, accountID) + } + return nil +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index f278e1761..ee77a65bb 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/networkmap.go b/management/server/networkmap.go new file mode 100644 index 000000000..2a0627643 --- /dev/null +++ b/management/server/networkmap.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + am.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (am *DefaultAccountManager) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := am.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := am.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + am.updateAccountInHolder(account) +} + +func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if am.experimentalNetworkMap(accountId) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + am.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool { + _, ok := am.expNewNetworkMapAIDs[accountId] + return am.expNewNetworkMap || ok +} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..0e6d1631b 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 66484d120..b740610c2 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..89ac419fd 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 30b7073ef..80ab7fc69 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -145,6 +145,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. am.BufferUpdateAccountPeers(ctx, accountID) @@ -321,6 +324,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + if peerLabelChanged || requiresPeerUpdates { am.UpdateAccountPeers(ctx, accountID) } else if sshChanged { @@ -381,6 +388,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) + } + + } + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -417,7 +436,13 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(peer.AccountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -690,6 +715,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } + + if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) + } + } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) @@ -776,6 +812,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } am.BufferUpdateAccountPeers(ctx, accountID) } @@ -783,6 +822,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("handlePeerNotFound: took %s", time.Since(start)) + }() if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -804,6 +847,11 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("LoginPeer: took %s", time.Since(start)) + }() + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -831,6 +879,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + startTransaction := time.Now() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { @@ -900,8 +949,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + startBuffer := time.Now() am.BufferUpdateAccountPeers(ctx, accountID) + log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer)) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) @@ -909,6 +965,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer // getPeerPostureChecks returns the posture checks for the peer. func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getPostureChecks: took %s", time.Since(start)) + }() + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err @@ -1014,9 +1075,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, emptyMap, nil, nil } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } } approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) @@ -1024,10 +1093,12 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } + startPosture := time.Now() postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) @@ -1037,7 +1108,13 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1167,11 +1244,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) + return + } } globalStart := time.Now() @@ -1204,6 +1288,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + if am.experimentalNetworkMap(accountID) { + am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) @@ -1241,7 +1329,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) start = time.Now() - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) start = time.Now() @@ -1257,7 +1351,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) }(peer) } @@ -1351,7 +1445,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1368,7 +1468,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. @@ -1511,6 +1611,10 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getPeerGroupIDs: took %s", time.Since(start)) + }() return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } @@ -1580,7 +1684,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto }, }, }, - NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 3b2ab87fc..e151f5abb 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) { } func TestAccountManager_GetNetworkMap(t *testing.T) { + testGetNetworkMapGeneral(t) +} + +func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testGetNetworkMapGeneral(t) +} + +func testGetNetworkMapGeneral(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } +func TestUpdateAccountPeers_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testUpdateAccountPeers(t) +} + func TestUpdateAccountPeers(t *testing.T) { + testUpdateAccountPeers(t) +} + +func testUpdateAccountPeers(t *testing.T) { testCases := []struct { name string peers int @@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel assert.Nil(t, update.Update.NetbirdConfig) - assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) - assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) + assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) + assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } }) } @@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } diff --git a/management/server/policy.go b/management/server/policy.go index 9e4b3f73a..ff02d46aa 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 943f2a970..f457b994b 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/route.go b/management/server/route.go index 4510426bb..05f7acf9e 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 2b2896572..f16b609f8 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,6 +5,9 @@ package settings import ( "context" "fmt" + "time" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -45,6 +48,11 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { } func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start)) + }() + if userID != activity.SystemInitiator { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4201b68f6..d83d160c3 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,8 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -46,6 +49,11 @@ const ( accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" + + pgMaxConnections = 30 + pgMinConnections = 1 + pgMaxConnLifetime = 60 * time.Minute + pgHealthCheckPeriod = 1 * time.Minute ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -55,6 +63,7 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + pool *pgxpool.Pool } type installation struct { @@ -307,6 +316,10 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("SavePeer: took %s", time.Since(start)) + }() // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID @@ -778,6 +791,13 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountPgx(ctx, accountID) + } + return s.getAccountGorm(ctx, accountID) +} + +func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -788,9 +808,19 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). - Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference - Preload(clause.Associations). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) @@ -800,70 +830,1147 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - for i, policy := range account.Policies { - var rules []*types.PolicyRule - err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - if err != nil { - return nil, status.Errorf(status.NotFound, "rule not found") - } - account.Policies[i].Rules = rules - } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { - account.SetupKeys[key.Key] = key.Copy() + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) for _, peer := range account.PeersG { - account.Peers[peer.ID] = peer.Copy() + account.Peers[peer.ID] = &peer } account.PeersG = nil - account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + pat.UserID = "" + user.PATs[pat.ID] = &pat } - account.Users[user.Id] = user.Copy() + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil } account.UsersG = nil - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { - account.Groups[group.ID] = group.Copy() + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group } account.GroupsG = nil - var groupPeers []types.GroupPeer - s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - Find(&groupPeers) - for _, groupPeer := range groupPeers { - if group, ok := account.Groups[groupPeer.GroupID]; ok { - group.Peers = append(group.Peers, groupPeer.PeerID) - } else { - log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + account.InitOnce() + return &account, nil +} + +func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for _, route := range account.RoutesG { - account.Routes[route.ID] = route.Copy() + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route } - account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for _, ns := range account.NameServerGroupsG { - account.NameServerGroups[ns.ID] = ns.Copy() + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil account.NameServerGroupsG = nil + return account, nil +} + +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network + network_identifier, network_net, network_dns, network_serial, + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups + FROM accounts WHERE id = $1` + + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups sql.NullString + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange sql.NullString + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups sql.NullString + networkNet sql.NullString + dnsSettingsDisabledGroups sql.NullString + networkIdentifier sql.NullString + networkDns sql.NullString + networkSerial sql.NullInt64 + createdAt sql.NullTime + ) + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &networkIdentifier, &networkNet, &networkDns, &networkSerial, + &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(err) + } + + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if networkNet.Valid { + _ = json.Unmarshal([]byte(networkNet.String), &account.Network.Net) + } + if createdAt.Valid { + account.CreatedAt = createdAt.Time + } + if dnsSettingsDisabledGroups.Valid { + _ = json.Unmarshal([]byte(dnsSettingsDisabledGroups.String), &account.DNSSettings.DisabledManagementGroups) + } + if networkIdentifier.Valid { + account.Network.Identifier = networkIdentifier.String + } + if networkDns.Valid { + account.Network.Dns = networkDns.String + } + if networkSerial.Valid { + account.Network.Serial = uint64(networkSerial.Int64) + } + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups.Valid { + _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups) + } + if sNetworkRange.Valid { + _ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups.Valid { + _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) + } + account.InitOnce() return &account, nil } +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, + revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var skCreatedAt, expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &skCreatedAt, + &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if skCreatedAt.Valid { + sk.CreatedAt = skCreatedAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, + inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, + meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, + meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, + meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, + peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, + location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var ( + lastLogin, createdAt sql.NullTime + sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + peerStatusLastSeen sql.NullTime + peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + ip, extraDNS, netAddr, env, flags, files, connIP []byte + metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString + metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString + metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString + locationCountryCode, locationCityName sql.NullString + locationGeoNameID sql.NullInt64 + ) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, + &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS, + &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, + &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, + &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, + &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, + &locationCountryCode, &locationCityName, &locationGeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + p.CreatedAt = createdAt.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if metaHostname.Valid { + p.Meta.Hostname = metaHostname.String + } + if metaGoOS.Valid { + p.Meta.GoOS = metaGoOS.String + } + if metaKernel.Valid { + p.Meta.Kernel = metaKernel.String + } + if metaCore.Valid { + p.Meta.Core = metaCore.String + } + if metaPlatform.Valid { + p.Meta.Platform = metaPlatform.String + } + if metaOS.Valid { + p.Meta.OS = metaOS.String + } + if metaOSVersion.Valid { + p.Meta.OSVersion = metaOSVersion.String + } + if metaWtVersion.Valid { + p.Meta.WtVersion = metaWtVersion.String + } + if metaUIVersion.Valid { + p.Meta.UIVersion = metaUIVersion.String + } + if metaKernelVersion.Valid { + p.Meta.KernelVersion = metaKernelVersion.String + } + if metaSystemSerialNumber.Valid { + p.Meta.SystemSerialNumber = metaSystemSerialNumber.String + } + if metaSystemProductName.Valid { + p.Meta.SystemProductName = metaSystemProductName.String + } + if metaSystemManufacturer.Valid { + p.Meta.SystemManufacturer = metaSystemManufacturer.String + } + if locationCountryCode.Valid { + p.Location.CountryCode = locationCountryCode.String + } + if locationCityName.Valid { + p.Location.CityName = locationCityName.String + } + if locationGeoNameID.Valid { + p.Location.GeoNameID = uint(locationGeoNameID.Int64) + } + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin, createdAt sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + u.CreatedAt = createdAt.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + var createdAt, updatedAt sql.NullTime + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &createdAt, + &updatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if createdAt.Valid { + account.Onboarding.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + account.Onboarding.UpdatedAt = updatedAt.Time + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed, createdAt sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &createdAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if createdAt.Valid { + pat.CreatedAt = createdAt.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return groupPeers, nil +} + func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) @@ -1054,6 +2161,10 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetAccountNetwork: took %s", time.Since(start)) + }() ctx, cancel := getDebuggingCtx(ctx) defer cancel() @@ -1095,6 +2206,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getAccountSettings: took %s", time.Since(start)) + }() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -1203,8 +2319,41 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } + pool, err := connectToPgDb(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} - return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = pgMaxConnections + config.MinConns = pgMinConnections + config.MaxConnLifetime = pgMaxConnLifetime + config.HealthCheckPeriod = pgHealthCheckPeriod + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil } // NewMysqlStore creates a new MySQL store. @@ -1273,7 +2422,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics, false) + store, err := NewPostgresqlStoreForTests(ctx, dsn, metrics, false) if err != nil { return nil, err } @@ -1293,6 +2442,50 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } +// used for tests only +func NewPostgresqlStoreForTests(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) { + db, err := gorm.Open(postgres.Open(dsn), getGormConfig()) + if err != nil { + return nil, err + } + pool, err := connectToPgDbForTests(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} + +// used for tests only +func connectToPgDbForTests(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 5 + config.MinConns = 1 + config.MaxConnLifetime = 30 * time.Second + config.HealthCheckPeriod = 10 * time.Second + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil +} + // NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB. func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewMysqlStore(ctx, dsn, metrics, false) diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go new file mode 100644 index 000000000..8ff04d68a --- /dev/null +++ b/management/server/store/sql_store_get_account_test.go @@ -0,0 +1,1089 @@ +package store + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/integration_reference" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads +// all fields and nested objects from the database, including deeply nested structures. +func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { + if testing.Short() { + t.Skip("skipping comprehensive test in short mode") + } + + ctx := context.Background() + store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Create comprehensive test data + accountID := "test-account-comprehensive" + userID1 := "user-1" + userID2 := "user-2" + peerID1 := "peer-1" + peerID2 := "peer-2" + peerID3 := "peer-3" + groupID1 := "group-1" + groupID2 := "group-2" + setupKeyID1 := "setup-key-1" + setupKeyID2 := "setup-key-2" + routeID1 := route.ID("route-1") + routeID2 := route.ID("route-2") + nsGroupID1 := "ns-group-1" + nsGroupID2 := "ns-group-2" + policyID1 := "policy-1" + policyID2 := "policy-2" + postureCheckID1 := "posture-check-1" + postureCheckID2 := "posture-check-2" + networkID1 := "network-1" + routerID1 := "router-1" + resourceID1 := "resource-1" + patID1 := "pat-1" + patID2 := "pat-2" + patID3 := "pat-3" + + now := time.Now().UTC().Truncate(time.Second) + lastLogin := now.Add(-24 * time.Hour) + patLastUsed := now.Add(-1 * time.Hour) + + // Build comprehensive account with all fields populated + account := &types.Account{ + Id: accountID, + CreatedBy: userID1, + CreatedAt: now, + Domain: "example.com", + DomainCategory: "business", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "test-network", + Net: net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }, + Dns: "test-dns", + Serial: 42, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"}, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour * 24 * 30, + GroupsPropagationEnabled: true, + JWTGroupsEnabled: true, + JWTGroupsClaimName: "groups", + JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"}, + RegularUsersViewBlocked: false, + Extra: &types.ExtraSettings{ + PeerApprovalEnabled: true, + IntegratedValidatorGroups: []string{"validator-1"}, + }, + }, + } + + // Create Setup Keys with all fields + setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour) + setupKey1LastUsed := now.Add(-2 * time.Hour) + setupKey1 := &types.SetupKey{ + Id: setupKeyID1, + AccountID: accountID, + Key: "setup-key-secret-1", + Name: "Setup Key 1", + Type: types.SetupKeyReusable, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey1ExpiresAt, + Revoked: false, + UsedTimes: 5, + LastUsed: &setupKey1LastUsed, + AutoGroups: []string{groupID1, groupID2}, + UsageLimit: 100, + Ephemeral: false, + } + + setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour) + setupKey2LastUsed := now.Add(-1 * time.Hour) + setupKey2 := &types.SetupKey{ + Id: setupKeyID2, + AccountID: accountID, + Key: "setup-key-secret-2", + Name: "Setup Key 2 (One-off)", + Type: types.SetupKeyOneOff, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey2ExpiresAt, + Revoked: true, + UsedTimes: 1, + LastUsed: &setupKey2LastUsed, + AutoGroups: []string{}, + UsageLimit: 1, + Ephemeral: true, + } + + account.SetupKeys = map[string]*types.SetupKey{ + setupKey1.Key: setupKey1, + setupKey2.Key: setupKey2, + } + + // Create Peers with comprehensive fields + peer1 := &nbpeer.Peer{ + ID: peerID1, + AccountID: accountID, + Key: "peer-key-1-AAAA", + Name: "Peer 1", + IP: net.ParseIP("100.64.0.1"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer1.example.com", + GoOS: "linux", + Kernel: "5.15.0", + Core: "x86_64", + Platform: "ubuntu", + OS: "Ubuntu 22.04", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + KernelVersion: "5.15.0-78-generic", + OSVersion: "22.04", + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"}, + {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"}, + }, + SystemSerialNumber: "ABC123", + SystemProductName: "Server Model X", + SystemManufacturer: "Dell Inc.", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-5 * time.Minute), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("203.0.113.10"), + CountryCode: "US", + CityName: "San Francisco", + GeoNameID: 5391959, + }, + SSHEnabled: true, + SSHKey: "ssh-rsa AAAAB3NzaC1...", + UserID: userID1, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + DNSLabel: "peer1", + CreatedAt: now.Add(-30 * 24 * time.Hour), + Ephemeral: false, + } + + peer2 := &nbpeer.Peer{ + ID: peerID2, + AccountID: accountID, + Key: "peer-key-2-BBBB", + Name: "Peer 2", + IP: net.ParseIP("100.64.0.2"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer2.example.com", + GoOS: "darwin", + Kernel: "22.0.0", + Core: "arm64", + Platform: "darwin", + OS: "macOS Ventura", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-1 * time.Hour), + Connected: false, + LoginExpired: true, + RequiresApproval: true, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("198.51.100.20"), + CountryCode: "GB", + CityName: "London", + GeoNameID: 2643743, + }, + SSHEnabled: false, + UserID: userID2, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: true, + DNSLabel: "peer2", + CreatedAt: now.Add(-15 * 24 * time.Hour), + Ephemeral: false, + } + + peer3 := &nbpeer.Peer{ + ID: peerID3, + AccountID: accountID, + Key: "peer-key-3-CCCC", + Name: "Peer 3 (Ephemeral)", + IP: net.ParseIP("100.64.0.3"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer3.example.com", + GoOS: "windows", + Platform: "windows", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-10 * time.Minute), + Connected: true, + }, + DNSLabel: "peer3", + CreatedAt: now.Add(-1 * time.Hour), + Ephemeral: true, + } + + account.Peers = map[string]*nbpeer.Peer{ + peerID1: peer1, + peerID2: peer2, + peerID3: peer3, + } + + // Create Users with PATs + pat1ExpirationDate := now.Add(90 * 24 * time.Hour) + pat1 := &types.PersonalAccessToken{ + ID: patID1, + Name: "PAT 1", + HashedToken: "hashed-token-1", + ExpirationDate: &pat1ExpirationDate, + CreatedAt: now.Add(-10 * 24 * time.Hour), + CreatedBy: userID1, + LastUsed: &patLastUsed, + } + + pat2ExpirationDate := now.Add(30 * 24 * time.Hour) + pat2 := &types.PersonalAccessToken{ + ID: patID2, + Name: "PAT 2", + HashedToken: "hashed-token-2", + ExpirationDate: &pat2ExpirationDate, + CreatedAt: now.Add(-5 * 24 * time.Hour), + CreatedBy: userID1, + } + + pat3ExpirationDate := now.Add(60 * 24 * time.Hour) + pat3 := &types.PersonalAccessToken{ + ID: patID3, + Name: "PAT 3", + HashedToken: "hashed-token-3", + ExpirationDate: &pat3ExpirationDate, + CreatedAt: now.Add(-2 * 24 * time.Hour), + CreatedBy: userID2, + } + + user1 := &types.User{ + Id: userID1, + AccountID: accountID, + Role: types.UserRoleOwner, + IsServiceUser: false, + NonDeletable: true, + AutoGroups: []string{groupID1}, + Issued: types.UserIssuedAPI, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 123, + IntegrationType: "azure_ad", + }, + CreatedAt: now.Add(-60 * 24 * time.Hour), + LastLogin: &lastLogin, + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID1: pat1, + patID2: pat2, + }, + } + + user2 := &types.User{ + Id: userID2, + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: true, + NonDeletable: false, + AutoGroups: []string{groupID2}, + Issued: types.UserIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 456, + IntegrationType: "google_workspace", + }, + CreatedAt: now.Add(-30 * 24 * time.Hour), + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID3: pat3, + }, + } + + account.Users = map[string]*types.User{ + userID1: user1, + userID2: user2, + } + + // Create Groups with peers and resources + group1 := &types.Group{ + ID: groupID1, + AccountID: accountID, + Name: "Group 1", + Issued: types.GroupIssuedAPI, + Peers: []string{peerID1, peerID2}, + Resources: []types.Resource{ + { + ID: "resource-1", + Type: types.ResourceTypeHost, + }, + }, + } + + group2 := &types.Group{ + ID: groupID2, + AccountID: accountID, + Name: "Group 2", + Issued: types.GroupIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 789, + IntegrationType: "okta", + }, + Peers: []string{peerID3}, + Resources: []types.Resource{}, + } + + account.Groups = map[string]*types.Group{ + groupID1: group1, + groupID2: group2, + } + + // Create Policies with Rules + policy1 := &types.Policy{ + ID: policyID1, + AccountID: accountID, + Name: "Policy 1", + Description: "Main access policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule-1", + PolicyID: policyID1, + Name: "Rule 1", + Description: "Allow access", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Ports: []string{}, + PortRanges: []types.RulePortRange{}, + Sources: []string{groupID1}, + Destinations: []string{groupID2}, + }, + { + ID: "rule-2", + PolicyID: policyID1, + Name: "Rule 2", + Description: "Block traffic on specific ports", + Enabled: true, + Action: types.PolicyTrafficActionDrop, + Bidirectional: false, + Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"22", "3389"}, + PortRanges: []types.RulePortRange{ + {Start: 8000, End: 8999}, + }, + Sources: []string{groupID2}, + Destinations: []string{groupID1}, + }, + }, + } + + policy2 := &types.Policy{ + ID: policyID2, + AccountID: accountID, + Name: "Policy 2", + Description: "Secondary policy", + Enabled: false, + Rules: []*types.PolicyRule{ + { + ID: "rule-3", + PolicyID: policyID2, + Name: "Rule 3", + Description: "UDP access", + Enabled: false, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolUDP, + Ports: []string{"53"}, + Sources: []string{groupID1}, + Destinations: []string{groupID1}, + }, + }, + } + + account.Policies = []*types.Policy{policy1, policy2} + + // Create Routes + route1 := &route.Route{ + ID: routeID1, + AccountID: accountID, + Network: netip.MustParsePrefix("10.0.0.0/24"), + NetworkType: route.IPv4Network, + Peer: peerID1, + PeerGroups: []string{}, + Description: "Route 1", + NetID: "net-id-1", + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{groupID1}, + AccessControlGroups: []string{groupID2}, + } + + route2 := &route.Route{ + ID: routeID2, + AccountID: accountID, + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetworkType: route.IPv4Network, + Peer: "", + PeerGroups: []string{groupID2}, + Description: "Route 2 (High Availability)", + NetID: "net-id-2", + Masquerade: false, + Metric: 100, + Enabled: true, + Groups: []string{groupID1, groupID2}, + AccessControlGroups: []string{groupID1}, + } + + account.Routes = map[route.ID]*route.Route{ + routeID1: route1, + routeID2: route2, + } + + // Create NameServer Groups + nsGroup1 := &nbdns.NameServerGroup{ + ID: nsGroupID1, + AccountID: accountID, + Name: "NS Group 1", + Description: "Primary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{groupID1, groupID2}, + Domains: []string{"example.com", "test.com"}, + Enabled: true, + Primary: true, + SearchDomainsEnabled: true, + } + + nsGroup2 := &nbdns.NameServerGroup{ + ID: nsGroupID2, + AccountID: accountID, + Name: "NS Group 2", + Description: "Secondary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{}, + Domains: []string{}, + Enabled: false, + Primary: false, + SearchDomainsEnabled: false, + } + + account.NameServerGroups = map[string]*nbdns.NameServerGroup{ + nsGroupID1: nsGroup1, + nsGroupID2: nsGroup2, + } + + // Create Posture Checks + postureCheck1 := &posture.Checks{ + ID: postureCheckID1, + AccountID: accountID, + Name: "Posture Check 1", + Description: "OS version check", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.24.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "16.0", + }, + Darwin: &posture.MinVersionCheck{ + MinVersion: "22.0.0", + }, + }, + }, + } + + postureCheck2 := &posture.Checks{ + ID: postureCheckID2, + AccountID: accountID, + Name: "Posture Check 2", + Description: "Geo location check", + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "US", + CityName: "San Francisco", + }, + { + CountryCode: "GB", + CityName: "London", + }, + }, + Action: "allow", + }, + PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{ + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + Action: "allow", + }, + }, + } + + account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2} + + // Create Networks + network1 := &networkTypes.Network{ + ID: networkID1, + AccountID: accountID, + Name: "Network 1", + Description: "Primary network", + } + + account.Networks = []*networkTypes.Network{network1} + + // Create Network Routers + router1 := &routerTypes.NetworkRouter{ + ID: routerID1, + AccountID: accountID, + NetworkID: networkID1, + Peer: peerID1, + PeerGroups: []string{}, + Masquerade: true, + Metric: 100, + } + + account.NetworkRouters = []*routerTypes.NetworkRouter{router1} + + // Create Network Resources + resource1 := &resourceTypes.NetworkResource{ + ID: resourceID1, + AccountID: accountID, + NetworkID: networkID1, + Name: "Resource 1", + Description: "Web server", + Prefix: netip.MustParsePrefix("192.168.1.100/32"), + Type: resourceTypes.Host, + } + + account.NetworkResources = []*resourceTypes.NetworkResource{resource1} + + // Create Onboarding + account.Onboarding = types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + SignupFormPending: false, + CreatedAt: now, + UpdatedAt: now, + } + + // Save the account to the database + err = store.SaveAccount(ctx, account) + require.NoError(t, err, "Failed to save comprehensive test account") + + // Retrieve the account from the database + retrievedAccount, err := store.GetAccount(ctx, accountID) + require.NoError(t, err, "Failed to retrieve account") + require.NotNil(t, retrievedAccount, "Retrieved account should not be nil") + + // ========== VALIDATE TOP-LEVEL FIELDS ========== + t.Run("TopLevelFields", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch") + assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch") + assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch") + assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch") + assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch") + assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true") + }) + + // ========== VALIDATE EMBEDDED NETWORK ========== + t.Run("EmbeddedNetwork", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Network, "Network should not be nil") + assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch") + assert.Equal(t, "test-dns", retrievedAccount.Network.Dns, "Network DNS mismatch") + assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch") + + expectedIP := net.ParseIP("100.64.0.0") + assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch") + expectedMask := net.CIDRMask(10, 32) + assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch") + }) + + // ========== VALIDATE DNS SETTINGS ========== + t.Run("DNSSettings", func(t *testing.T) { + assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2") + }) + + // ========== VALIDATE SETTINGS ========== + t.Run("Settings", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil") + assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch") + assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch") + assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch") + assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch") + assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch") + assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2") + assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch") + + // Validate Extra Settings + require.NotNil(t, retrievedAccount.Settings.Extra, "Extra settings should not be nil") + assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch") + assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch") + assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0]) + }) + + // ========== VALIDATE SETUP KEYS ========== + t.Run("SetupKeys", func(t *testing.T) { + require.Len(t, retrievedAccount.SetupKeys, 2, "Should have 2 setup keys") + + // Validate Setup Key 1 + sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"] + require.True(t, exists, "Setup key 1 should exist") + assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch") + assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch") + assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked") + assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch") + assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch") + assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral") + assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch") + assert.Contains(t, sk1.AutoGroups, groupID1) + assert.Contains(t, sk1.AutoGroups, groupID2) + + // Validate Setup Key 2 + sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"] + require.True(t, exists, "Setup key 2 should exist") + assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch") + assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch") + assert.True(t, sk2.Revoked, "Setup key 2 should be revoked") + assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch") + assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch") + assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral") + assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups") + }) + + // ========== VALIDATE PEERS ========== + t.Run("Peers", func(t *testing.T) { + require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers") + + // Validate Peer 1 + p1, exists := retrievedAccount.Peers[peerID1] + require.True(t, exists, "Peer 1 should exist") + assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") + assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") + assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") + assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") + assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") + assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled") + assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral") + assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch") + + // Validate Peer 1 Meta + assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch") + assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch") + assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch") + assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch") + assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch") + assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch") + assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch") + assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch") + assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch") + assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch") + + // Validate Network Addresses + assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses") + assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch") + assert.Equal(t, "00:11:22:33:44:55", p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch") + assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch") + assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch") + + // Validate Peer 1 Status + require.NotNil(t, p1.Status, "Peer 1 status should not be nil") + assert.True(t, p1.Status.Connected, "Peer 1 should be connected") + assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not be expired") + assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval") + + // Validate Peer 1 Location + assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch") + assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch") + assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch") + assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch") + + // Validate Peer 2 + p2, exists := retrievedAccount.Peers[peerID2] + require.True(t, exists, "Peer 2 should exist") + assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch") + assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch") + assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled") + assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled") + assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled") + + // Validate Peer 2 Status + require.NotNil(t, p2.Status, "Peer 2 status should not be nil") + assert.False(t, p2.Status.Connected, "Peer 2 should not be connected") + assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired") + assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval") + + // Validate Peer 3 (Ephemeral) + p3, exists := retrievedAccount.Peers[peerID3] + require.True(t, exists, "Peer 3 should exist") + assert.True(t, p3.Ephemeral, "Peer 3 should be ephemeral") + assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch") + }) + + // ========== VALIDATE USERS ========== + t.Run("Users", func(t *testing.T) { + require.Len(t, retrievedAccount.Users, 2, "Should have 2 users") + + // Validate User 1 + u1, exists := retrievedAccount.Users[userID1] + require.True(t, exists, "User 1 should exist") + assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch") + assert.False(t, u1.IsServiceUser, "User 1 should not be a service user") + assert.True(t, u1.NonDeletable, "User 1 should be non-deletable") + assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch") + assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch") + assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1") + assert.False(t, u1.Blocked, "User 1 should not be blocked") + require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil") + assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch") + + // Validate User 1 Integration Reference + assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch") + assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch") + + // Validate User 1 PATs + require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs") + + pat1Retrieved, exists := u1.PATs[patID1] + require.True(t, exists, "PAT 1 should exist") + assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch") + assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch") + require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil") + assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch") + assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch") + assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared") + + pat2Retrieved, exists := u1.PATs[patID2] + require.True(t, exists, "PAT 2 should exist") + assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch") + assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil") + + // Validate User 2 + u2, exists := retrievedAccount.Users[userID2] + require.True(t, exists, "User 2 should exist") + assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch") + assert.True(t, u2.IsServiceUser, "User 2 should be a service user") + assert.False(t, u2.NonDeletable, "User 2 should be deletable") + assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch") + assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch") + + // Validate User 2 PATs + require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT") + pat3Retrieved, exists := u2.PATs[patID3] + require.True(t, exists, "PAT 3 should exist") + assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch") + }) + + // ========== VALIDATE GROUPS ========== + t.Run("Groups", func(t *testing.T) { + require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups") + + // Validate Group 1 + g1, exists := retrievedAccount.Groups[groupID1] + require.True(t, exists, "Group 1 should exist") + assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch") + assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch") + assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers") + assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1") + assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2") + + // Validate Group 1 Resources + assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource") + assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch") + assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch") + + // Validate Group 2 + g2, exists := retrievedAccount.Groups[groupID2] + require.True(t, exists, "Group 2 should exist") + assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch") + assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch") + assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer") + assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3") + assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources") + + // Validate Group 2 Integration Reference + assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID mismatch") + assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch") + }) + + // ========== VALIDATE POLICIES ========== + t.Run("Policies", func(t *testing.T) { + require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies") + + // Validate Policy 1 + pol1 := retrievedAccount.Policies[0] + if pol1.ID != policyID1 { + pol1 = retrievedAccount.Policies[1] + } + assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch") + assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch") + assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch") + assert.True(t, pol1.Enabled, "Policy 1 should be enabled") + + // Validate Policy 1 Rules + require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules") + + rule1 := pol1.Rules[0] + assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch") + assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch") + assert.True(t, rule1.Enabled, "Rule 1 should be enabled") + assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch") + assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch") + assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch") + assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source") + assert.Len(t, rule1.Destinations, 1, "Rule 1 destinations length mismatch") + assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination") + + rule2 := pol1.Rules[1] + assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch") + assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch") + assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch") + assert.Len(t, rule2.Ports, 2, "Rule 2 ports length mismatch") + assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22") + assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389") + assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch") + assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch") + assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch") + + // Validate Policy 2 + pol2 := retrievedAccount.Policies[1] + if pol2.ID != policyID2 { + pol2 = retrievedAccount.Policies[0] + } + assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch") + assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch") + assert.False(t, pol2.Enabled, "Policy 2 should be disabled") + require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule") + + rule3 := pol2.Rules[0] + assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch") + assert.False(t, rule3.Enabled, "Rule 3 should be disabled") + assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch") + }) + + // ========== VALIDATE ROUTES ========== + t.Run("Routes", func(t *testing.T) { + require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes") + + // Validate Route 1 + r1, exists := retrievedAccount.Routes[routeID1] + require.True(t, exists, "Route 1 should exist") + assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch") + assert.Equal(t, route.IPv4Network, r1.NetworkType, "Route 1 network type mismatch") + assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch") + assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty") + assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch") + assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled") + assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch") + assert.True(t, r1.Enabled, "Route 1 should be enabled") + assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch") + assert.Contains(t, r1.Groups, groupID1, "Route 1 should have group1") + assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch") + assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL") + + // Validate Route 1 Network CIDR + assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch") + + // Validate Route 2 + r2, exists := retrievedAccount.Routes[routeID2] + require.True(t, exists, "Route 2 should exist") + assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch") + assert.Empty(t, r2.Peer, "Route 2 peer should be empty") + assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch") + assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group") + assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled") + assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch") + assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch") + }) + + // ========== VALIDATE NAME SERVER GROUPS ========== + t.Run("NameServerGroups", func(t *testing.T) { + require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups") + + // Validate NS Group 1 + nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1] + require.True(t, exists, "NS Group 1 should exist") + assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch") + assert.Equal(t, "Primary nameservers", nsg1.Description, "NS Group 1 description mismatch") + assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled") + assert.True(t, nsg1.Primary, "NS Group 1 should be primary") + assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled") + assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared") + + // Validate NS Group 1 NameServers + require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers") + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch") + assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch") + assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch") + assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch") + + // Validate NS Group 1 Groups and Domains + assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch") + assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1") + assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2") + assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch") + assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain") + assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain") + + // Validate NS Group 2 + nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2] + require.True(t, exists, "NS Group 2 should exist") + assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch") + assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled") + assert.False(t, nsg2.Primary, "NS Group 2 should not be primary") + assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled") + assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver") + assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty groups") + assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains") + }) + + // ========== VALIDATE POSTURE CHECKS ========== + t.Run("PostureChecks", func(t *testing.T) { + require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks") + + // Find posture checks by ID + var pc1, pc2 *posture.Checks + for _, pc := range retrievedAccount.PostureChecks { + if pc.ID == postureCheckID1 { + pc1 = pc + } else if pc.ID == postureCheckID2 { + pc2 = pc + } + } + + // Validate Posture Check 1 + require.NotNil(t, pc1, "Posture check 1 should exist") + assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch") + assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch") + + // Validate NB Version Check + require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil") + assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch") + + // Validate OS Version Check + require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil") + require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil") + assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch") + require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil") + assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch") + + // Validate Posture Check 2 + require.NotNil(t, pc2, "Posture check 2 should exist") + assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch") + + // Validate Geo Location Check + require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil") + assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch") + assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations") + assert.Equal(t, "US", pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch") + assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch") + assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch") + assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch") + + // Validate Peer Network Range Check + require.NotNil(t, pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil") + assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch") + assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range") + }) + + // ========== VALIDATE NETWORKS ========== + t.Run("Networks", func(t *testing.T) { + require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network") + + net1 := retrievedAccount.Networks[0] + assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch") + assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch") + assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch") + }) + + // ========== VALIDATE NETWORK ROUTERS ========== + t.Run("NetworkRouters", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router") + + router := retrievedAccount.NetworkRouters[0] + assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch") + assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch") + assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch") + assert.Empty(t, router.PeerGroups, "Router 1 peer groups should be empty") + assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled") + assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch") + }) + + // ========== VALIDATE NETWORK RESOURCES ========== + t.Run("NetworkResources", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource") + + res := retrievedAccount.NetworkResources[0] + assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch") + assert.Equal(t, networkID1, res.NetworkID, "Resource 1 network ID mismatch") + assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch") + assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch") + assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch") + assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch") + }) + + // ========== VALIDATE ONBOARDING ========== + t.Run("Onboarding", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch") + assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending") + assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending") + assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch") + }) + + t.Log("✅ All comprehensive account field validations passed!") +} diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go new file mode 100644 index 000000000..350a1da83 --- /dev/null +++ b/management/server/store/sqlstore_bench_test.go @@ -0,0 +1,951 @@ +package store + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sort" + "sync" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/jackc/pgx/v5/pgxpool" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Omit("GroupsG"). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload(clause.Associations). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 12 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + return pool, nil +} + +func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { + cleanup, dsn, err := testutil.CreatePostgresTestContainer() + if err != nil { + b.Fatalf("failed to create test container: %v", err) + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + pool, err := connectDBforTest(context.Background(), dsn) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + models := []interface{}{ + &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, + &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.Policy{}, &types.PolicyRule{}, &route.Route{}, + &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, + &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &types.AccountOnboarding{}, + } + + for i := len(models) - 1; i >= 0; i-- { + err := db.Migrator().DropTable(models[i]) + if err != nil { + b.Fatalf("failed to drop table: %v", err) + } + } + + err = db.AutoMigrate(models...) + if err != nil { + b.Fatalf("failed to migrate database: %v", err) + } + + store := &SqlStore{ + db: db, + pool: pool, + } + + const ( + accountID = "benchmark-account-id" + numUsers = 20 + numPatsPerUser = 3 + numSetupKeys = 25 + numPeers = 200 + numGroups = 30 + numPolicies = 50 + numRulesPerPolicy = 10 + numRoutes = 40 + numNSGroups = 10 + numPostureChecks = 15 + numNetworks = 5 + numNetworkRouters = 5 + numNetworkResources = 10 + ) + + _, ipNet, _ := net.ParseCIDR("100.64.0.0/10") + acc := types.Account{ + Id: accountID, + CreatedBy: "benchmark-user", + CreatedAt: time.Now(), + Domain: "benchmark.com", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "benchmark-net", + Net: *ipNet, + Serial: 1, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"group-disabled-1"}, + }, + Settings: &types.Settings{}, + } + if err := db.Create(&acc).Error; err != nil { + b.Fatalf("create account: %v", err) + } + + var setupKeys []types.SetupKey + for i := 0; i < numSetupKeys; i++ { + setupKeys = append(setupKeys, types.SetupKey{ + Id: fmt.Sprintf("keyid-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), + Name: fmt.Sprintf("Benchmark Key %d", i), + ExpiresAt: &time.Time{}, + }) + } + if err := db.Create(&setupKeys).Error; err != nil { + b.Fatalf("create setup keys: %v", err) + } + + var peers []nbpeer.Peer + for i := 0; i < numPeers; i++ { + peers = append(peers, nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("peerkey-%d", i), + IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + Name: fmt.Sprintf("peer-name-%d", i), + Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, + }) + } + if err := db.Create(&peers).Error; err != nil { + b.Fatalf("create peers: %v", err) + } + + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + user := types.User{Id: userID, AccountID: accountID} + if err := db.Create(&user).Error; err != nil { + b.Fatalf("create user %s: %v", userID, err) + } + + var pats []types.PersonalAccessToken + for j := 0; j < numPatsPerUser; j++ { + pats = append(pats, types.PersonalAccessToken{ + ID: fmt.Sprintf("pat-%d-%d", i, j), + UserID: userID, + Name: fmt.Sprintf("PAT %d for User %d", j, i), + }) + } + if err := db.Create(&pats).Error; err != nil { + b.Fatalf("create pats for user %s: %v", userID, err) + } + } + + var groups []*types.Group + for i := 0; i < numGroups; i++ { + groups = append(groups, &types.Group{ + ID: fmt.Sprintf("group-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), + }) + } + if err := db.Create(&groups).Error; err != nil { + b.Fatalf("create groups: %v", err) + } + + for i := 0; i < numPolicies; i++ { + policyID := fmt.Sprintf("policy-%d", i) + policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true} + if err := db.Create(&policy).Error; err != nil { + b.Fatalf("create policy %s: %v", policyID, err) + } + + var rules []*types.PolicyRule + for j := 0; j < numRulesPerPolicy; j++ { + rules = append(rules, &types.PolicyRule{ + ID: fmt.Sprintf("rule-%d-%d", i, j), + PolicyID: policyID, + Name: fmt.Sprintf("Rule %d for Policy %d", j, i), + Enabled: true, + Protocol: "all", + }) + } + if err := db.Create(&rules).Error; err != nil { + b.Fatalf("create rules for policy %s: %v", policyID, err) + } + } + + var routes []route.Route + for i := 0; i < numRoutes; i++ { + routes = append(routes, route.Route{ + ID: route.ID(fmt.Sprintf("route-%d", i)), + AccountID: accountID, + Description: fmt.Sprintf("Route %d", i), + Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)), + Enabled: true, + }) + } + if err := db.Create(&routes).Error; err != nil { + b.Fatalf("create routes: %v", err) + } + + var nsGroups []nbdns.NameServerGroup + for i := 0; i < numNSGroups; i++ { + nsGroups = append(nsGroups, nbdns.NameServerGroup{ + ID: fmt.Sprintf("nsg-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("NS Group %d", i), + Description: "Benchmark NS Group", + Enabled: true, + }) + } + if err := db.Create(&nsGroups).Error; err != nil { + b.Fatalf("create nsgroups: %v", err) + } + + var postureChecks []*posture.Checks + for i := 0; i < numPostureChecks; i++ { + postureChecks = append(postureChecks, &posture.Checks{ + ID: fmt.Sprintf("pc-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Posture Check %d", i), + }) + } + if err := db.Create(&postureChecks).Error; err != nil { + b.Fatalf("create posture checks: %v", err) + } + + var networks []*networkTypes.Network + for i := 0; i < numNetworks; i++ { + networks = append(networks, &networkTypes.Network{ + ID: fmt.Sprintf("nettype-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Network Type %d", i), + }) + } + if err := db.Create(&networks).Error; err != nil { + b.Fatalf("create networks: %v", err) + } + + var networkRouters []*routerTypes.NetworkRouter + for i := 0; i < numNetworkRouters; i++ { + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Peer: peers[i%numPeers].ID, + }) + } + if err := db.Create(&networkRouters).Error; err != nil { + b.Fatalf("create network routers: %v", err) + } + + var networkResources []*resourceTypes.NetworkResource + for i := 0; i < numNetworkResources; i++ { + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("resource-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Name: fmt.Sprintf("Resource %d", i), + }) + } + if err := db.Create(&networkResources).Error; err != nil { + b.Fatalf("create network resources: %v", err) + } + + onboarding := types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + } + if err := db.Create(&onboarding).Error; err != nil { + b.Fatalf("create onboarding: %v", err) + } + + return store, cleanup, accountID +} + +func BenchmarkGetAccount(b *testing.B) { + store, cleanup, accountID := setupBenchmarkDB(b) + defer cleanup() + ctx := context.Background() + b.ResetTimer() + b.ReportAllocs() + b.Run("old", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountSlow(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountSlow failed: %v", err) + } + } + }) + b.Run("gorm opt", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountGormOpt(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountFast failed: %v", err) + } + } + }) + b.Run("raw", func(b *testing.B) { + for range b.N { + _, err := store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountPureSQL failed: %v", err) + } + } + }) + store.pool.Close() +} + +func TestAccountEquivalence(t *testing.T) { + store, cleanup, accountID := setupBenchmarkDB(t) + defer cleanup() + ctx := context.Background() + + type getAccountFunc func(context.Context, string) (*types.Account, error) + + tests := []struct { + name string + expectedF getAccountFunc + actualF getAccountFunc + }{ + {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt}, + {"old vs raw", store.GetAccountSlow, store.GetAccount}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, errOld := tt.expectedF(ctx, accountID) + assert.NoError(t, errOld, "expected function should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := tt.actualF(ctx, accountID) + assert.NoError(t, errNew, "actual function should not return an error") + assert.NotNil(t, actual, "actual should not be nil") + testAccountEquivalence(t, expected, actual) + }) + } + + expected, errOld := store.GetAccountSlow(ctx, accountID) + assert.NoError(t, errOld, "GetAccountSlow should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := store.GetAccount(ctx, accountID) + assert.NoError(t, errNew, "GetAccount (new) should not return an error") + assert.NotNil(t, actual, "actual should not be nil") +} + +func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { + assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") + assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") + assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal") + assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal") + assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal") + assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal") + assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal") + assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal") + + assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements") + for key, oldVal := range expected.SetupKeys { + newVal, ok := actual.SetupKeys[key] + assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key) + } + + assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements") + for key, oldVal := range expected.Peers { + newVal, ok := actual.Peers[key] + assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements") + for key, oldUser := range expected.Users { + newUser, ok := actual.Users[key] + assert.True(t, ok, "User with ID '%s' should exist in new account", key) + + assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key) + for patKey, oldPAT := range oldUser.PATs { + newPAT, patOk := newUser.PATs[patKey] + assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) + assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key) + } + + oldUser.PATs = nil + newUser.PATs = nil + assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key) + } + + assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements") + for key, oldVal := range expected.Groups { + newVal, ok := actual.Groups[key] + assert.True(t, ok, "Group with ID '%s' should exist in new account", key) + sort.Strings(oldVal.Peers) + sort.Strings(newVal.Peers) + assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements") + for key, oldVal := range expected.Routes { + newVal, ok := actual.Routes[key] + assert.True(t, ok, "Route with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key) + } + + assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements") + for key, oldVal := range expected.NameServerGroups { + newVal, ok := actual.NameServerGroups[key] + assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") + sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID }) + sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID }) + for i := range expected.Policies { + sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID }) + sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID }) + assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID) + } + + assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") + sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID }) + sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID }) + for i := range expected.PostureChecks { + assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID) + } + + assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") + sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID }) + sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID }) + for i := range expected.Networks { + assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID) + } + + assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") + sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID }) + sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID }) + for i := range expected.NetworkRouters { + assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID) + } + + assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") + sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID }) + sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID }) + for i := range expected.NetworkResources { + assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID) + } +} + +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 21b660d96..007e2b739 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -468,6 +468,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) closeConnection := func() { cleanup() store.Close(ctx) + if store.pool != nil { + store.pool.Close() + } } return store, closeConnection, nil @@ -487,12 +490,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + if err != nil { return nil, nil, err } @@ -519,12 +528,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) } - db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err) } + sqlDB, err := db.DB() + if err != nil { + return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err) + } + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(1) + dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB.Close() + if err != nil { return nil, nil, err } @@ -537,6 +556,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return store, cleanup, nil } +func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) { + var db *gorm.DB + var err error + + for i := range maxRetries { + switch engine { + case types.PostgresStoreEngine: + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + case types.MysqlStoreEngine: + db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + } + + if err == nil { + return db, nil + } + + if i < maxRetries-1 { + waitTime := time.Duration(100*(i+1)) * time.Millisecond + time.Sleep(waitTime) + } + } + + return nil, err +} + func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) { dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) @@ -544,21 +588,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func( return "", nil, fmt.Errorf("failed to create database: %v", err) } - var err error + originalDSN := dsn + cleanup := func() { + var dropDB *gorm.DB + var err error + switch engine { case types.PostgresStoreEngine: - err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error + dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error + case types.MysqlStoreEngine: - // err = killMySQLConnections(dsn, dbName) - err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error + dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error } + if err != nil { log.Errorf("failed to drop database %s: %v", dbName, err) - panic(err) } - sqlDB, _ := db.DB() - _ = sqlDB.Close() } return replaceDBName(dsn, dbName), cleanup, nil diff --git a/management/server/types/account.go b/management/server/types/account.go index 50bdc6ab3..dd6052498 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -8,6 +8,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/hashicorp/go-multierror" @@ -87,6 +88,13 @@ type Account struct { NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` + + NetworkMapCache *NetworkMapBuilder `gorm:"-"` + nmapInitOnce *sync.Once `gorm:"-"` +} + +func (a *Account) InitOnce() { + a.nmapInitOnce = &sync.Once{} } // this class is used by gorm only @@ -257,6 +265,9 @@ func (a *Account) GetPeerNetworkMap( metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetPeerNetworkMap: took %s", time.Since(start)) + }() peer := a.Peers[peerID] if peer == nil { @@ -890,6 +901,8 @@ func (a *Account) Copy() *Account { NetworkRouters: networkRouters, NetworkResources: networkResources, Onboarding: a.Onboarding, + NetworkMapCache: a.NetworkMapCache, + nmapInitOnce: a.nmapInitOnce, } } diff --git a/management/server/types/holder.go b/management/server/types/holder.go new file mode 100644 index 000000000..3996db2b6 --- /dev/null +++ b/management/server/types/holder.go @@ -0,0 +1,43 @@ +package types + +import ( + "context" + "sync" +) + +type Holder struct { + mu sync.RWMutex + accounts map[string]*Account +} + +func NewHolder() *Holder { + return &Holder{ + accounts: make(map[string]*Account), + } +} + +func (h *Holder) GetAccount(id string) *Account { + h.mu.RLock() + defer h.mu.RUnlock() + return h.accounts[id] +} + +func (h *Holder) AddAccount(account *Account) { + h.mu.Lock() + defer h.mu.Unlock() + h.accounts[account.Id] = account +} + +func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { + h.mu.Lock() + defer h.mu.Unlock() + if acc, ok := h.accounts[id]; ok { + return acc, nil + } + account, err := accGetter(context.Background(), id) + if err != nil { + return nil, err + } + h.accounts[id] = account + return account, nil +} diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go new file mode 100644 index 000000000..c1099726f --- /dev/null +++ b/management/server/types/networkmap.go @@ -0,0 +1,58 @@ +package types + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { + if a.NetworkMapCache != nil { + return + } + a.nmapInitOnce.Do(func() { + a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) + }) +} + +func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} + +func (a *Account) GetPeerNetworkMapExp( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + a.initNetworkMapBuilder(validatedPeers) + return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics) +} + +func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerAddedIncremental(peerId) +} + +func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerDeleted(peerId) +} + +func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { + if a.NetworkMapCache == nil { + return + } + a.NetworkMapCache.UpdatePeer(peer) +} + +func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go new file mode 100644 index 000000000..d85aaabb2 --- /dev/null +++ b/management/server/types/networkmap_golden_test.go @@ -0,0 +1,1069 @@ +package types_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "slices" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// update flag is used to update the golden file. +// example: go test ./... -v -update +// var update = flag.Bool("update", false, "update golden files") + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. + expiredPeerID = "peer-98" // This peer will be online but with an expired session. + offlinePeerID = "peer-99" // This peer will be completely offline. + routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. + testAccountID = "account-golden-test" +) + +func TestGetPeerNetworkMap_Golden(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file") +} + +func BenchmarkGetPeerNetworkMap(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + b.ResetTimer() + b.Run("old builder", func(b *testing.B) { + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + b.ResetTimer() + b.Run("new builder", func(b *testing.B) { + for range b.N { + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + for _, peerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") + + t.Log("Update golden file with new peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newPeerID) + require.NoError(t, err, "error adding peer to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") + t.Log("Update golden file with OnPeerAdded...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newPeerID := "peer-new-101" + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: net.IP{100, 64, 1, 1}, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + } + + account.Peers[newPeerID] = newPeer + account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) + account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) + validatedPeersMap[newPeerID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") + + t.Log("Update golden file with new router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newRouterID) + require.NoError(t, err, "error adding router to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") + + t.Log("Update golden file with OnPeerAdded router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newRouterID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedPeerID := "peer-25" // peer from devs group + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedPeerID := "peer-25" // devs group peer + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedPeerID) + require.NoError(t, err, "error deleting peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") + t.Log("Update golden file with OnPeerDeleted...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedRouterID) + require.NoError(t, err, "error deleting routing peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err) + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") + + t.Log("Update golden file with deleted router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err) + + require.JSONEq(t, string(expectedJSON), string(jsonData), + "network map after deleting router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + deletedPeerID := "peer-25" + + delete(account.Peers, deletedPeerID) + account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + delete(validatedPeersMap, deletedPeerID) + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + b.ResetTimer() + b.Run("old builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerDeleted(deletedPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { + for _, peer := range networkMap.Peers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + for _, peer := range networkMap.OfflinePeers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + + sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) + sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) + sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) + + sort.Slice(networkMap.FirewallRules, func(i, j int) bool { + r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] + if r1.PeerIP != r2.PeerIP { + return r1.PeerIP < r2.PeerIP + } + if r1.Protocol != r2.Protocol { + return r1.Protocol < r2.Protocol + } + if r1.Direction != r2.Direction { + return r1.Direction < r2.Direction + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + return r1.Port < r2.Port + }) + + sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { + r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] + if r1.RouteID != r2.RouteID { + return r1.RouteID < r2.RouteID + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + if r1.Destination != r2.Destination { + return r1.Destination < r2.Destination + } + if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 { + if r1.SourceRanges[0] != r2.SourceRanges[0] { + return r1.SourceRanges[0] < r2.SourceRanges[0] + } + } + return r1.Port < r2.Port + }) + + for _, ranges := range networkMap.RoutesFirewallRules { + sort.Slice(ranges.SourceRanges, func(i, j int) bool { + return ranges.SourceRanges[i] < ranges.SourceRanges[j] + }) + } +} + +func createTestAccountWithEntities() *types.Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + + } + + groups := map[string]*types.Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*types.Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*types.PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &types.Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &types.Network{ + Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*dns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go new file mode 100644 index 000000000..58f1bfa30 --- /dev/null +++ b/management/server/types/networkmapbuilder.go @@ -0,0 +1,1932 @@ +package types + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +const ( + allPeers = "0.0.0.0" + fw = "fw:" + rfw = "route-fw:" + nr = "network-resource-" +) + +type NetworkMapCache struct { + globalRoutes map[route.ID]*route.Route + globalRules map[string]*FirewallRule //ruleId + globalRouteRules map[string]*RouteFirewallRule //ruleId + globalPeers map[string]*nbpeer.Peer + + groupToPeers map[string][]string + peerToGroups map[string][]string + policyToRules map[string][]*PolicyRule //policyId + groupToPolicies map[string][]*Policy + groupToRoutes map[string][]*route.Route + peerToRoutes map[string][]*route.Route + + peerACLs map[string]*PeerACLView + peerRoutes map[string]*PeerRoutesView + peerDNS map[string]*nbdns.Config + + resourceRouters map[string]map[string]*routerTypes.NetworkRouter + resourcePolicies map[string][]*Policy + + globalResources map[string]*resourceTypes.NetworkResource // resourceId + + acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info + noACGRoutes map[route.ID]*RouteOwnerInfo + + mu sync.RWMutex +} + +type RouteOwnerInfo struct { + PeerID string + RouteID route.ID +} + +type PeerACLView struct { + ConnectedPeerIDs []string + FirewallRuleIDs []string +} + +type PeerRoutesView struct { + OwnRouteIDs []route.ID + NetworkResourceIDs []route.ID + InheritedRouteIDs []route.ID + RouteFirewallRuleIDs []string +} + +type NetworkMapBuilder struct { + account atomic.Pointer[Account] + cache *NetworkMapCache + validatedPeers map[string]struct{} +} + +func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { + builder := &NetworkMapBuilder{ + cache: &NetworkMapCache{ + globalRoutes: make(map[route.ID]*route.Route), + globalRules: make(map[string]*FirewallRule), + globalRouteRules: make(map[string]*RouteFirewallRule), + globalPeers: make(map[string]*nbpeer.Peer), + groupToPeers: make(map[string][]string), + peerToGroups: make(map[string][]string), + policyToRules: make(map[string][]*PolicyRule), + groupToPolicies: make(map[string][]*Policy), + groupToRoutes: make(map[string][]*route.Route), + peerToRoutes: make(map[string][]*route.Route), + peerACLs: make(map[string]*PeerACLView), + peerRoutes: make(map[string]*PeerRoutesView), + peerDNS: make(map[string]*nbdns.Config), + globalResources: make(map[string]*resourceTypes.NetworkResource), + acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), + noACGRoutes: make(map[route.ID]*RouteOwnerInfo), + }, + validatedPeers: make(map[string]struct{}), + } + builder.account.Store(account) + maps.Copy(builder.validatedPeers, validatedPeers) + + builder.initialBuild(account) + + return builder +} + +func (b *NetworkMapBuilder) initialBuild(account *Account) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + start := time.Now() + + b.buildGlobalIndexes(account) + + resourceRouters := account.GetResourceRoutersMap() + resourcePolicies := account.GetResourcePoliciesMap() + b.cache.resourceRouters = resourceRouters + b.cache.resourcePolicies = resourcePolicies + + for peerID := range account.Peers { + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + } + + log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) +} + +func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { + clear(b.cache.globalPeers) + clear(b.cache.groupToPeers) + clear(b.cache.peerToGroups) + clear(b.cache.policyToRules) + clear(b.cache.groupToPolicies) + clear(b.cache.globalRoutes) + clear(b.cache.globalRules) + clear(b.cache.globalRouteRules) + clear(b.cache.globalResources) + clear(b.cache.groupToRoutes) + clear(b.cache.peerToRoutes) + clear(b.cache.acgToRoutes) + clear(b.cache.noACGRoutes) + + maps.Copy(b.cache.globalPeers, account.Peers) + + for groupID, group := range account.Groups { + peersCopy := make([]string, len(group.Peers)) + copy(peersCopy, group.Peers) + b.cache.groupToPeers[groupID] = peersCopy + + for _, peerID := range group.Peers { + b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) + } + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + b.cache.policyToRules[policy.ID] = policy.Rules + + affectedGroups := make(map[string]struct{}) + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, groupID := range rule.Sources { + affectedGroups[groupID] = struct{}{} + } + for _, groupID := range rule.Destinations { + affectedGroups[groupID] = struct{}{} + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + groupId := rule.SourceResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + groupId := rule.DestinationResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) + } + } + + for groupID := range affectedGroups { + b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) + } + } + + for _, resource := range account.NetworkResources { + if !resource.Enabled { + continue + } + b.cache.globalResources[resource.ID] = resource + } + + for _, r := range account.Routes { + if !r.Enabled { + continue + } + for _, groupID := range r.PeerGroups { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } +} + +func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { + peer := account.GetPeer(peerID) + if peer == nil { + return + } + + allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers) + + isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) + + var emptyExpiredPeers []*nbpeer.Peer + finalAllPeers := b.addNetworksRoutingPeers( + networkResourcesRoutes, + peer, + allPotentialPeers, + emptyExpiredPeers, + isRouter, + sourcePeers, + ) + + view := &PeerACLView{ + ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), + FirewallRuleIDs: make([]string, 0, len(firewallRules)), + } + + for _, p := range finalAllPeers { + view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) + } + + for _, rule := range firewallRules { + ruleID := b.generateFirewallRuleID(rule) + view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) + b.cache.globalRules[ruleID] = rule + } + + b.cache.peerACLs[peerID] = view +} + +func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, + validatedPeersMap map[string]struct{}, +) ([]*nbpeer.Peer, []*FirewallRule) { + ctx := context.Background() + + peerID := peer.ID + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + fwRules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + for _, group := range peerGroups { + policies := b.cache.groupToPolicies[group] + for _, policy := range policies { + if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { + continue + } + rules := b.cache.policyToRules[policy.ID] + for _, rule := range rules { + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peerInSources = rule.SourceResource.ID == peerID + } else { + peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peerInDestinations = rule.DestinationResource.ID == peerID + } else { + peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) + } + + if !peerInSources && !peerInDestinations { + continue + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peer := account.GetPeer(rule.SourceResource.ID) + if peer != nil { + sourcePeers = []*nbpeer.Peer{peer} + } + } else { + sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peer := account.GetPeer(rule.DestinationResource.ID) + if peer != nil { + destinationPeers = []*nbpeer.Peer{peer} + } + } else { + destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) + } + + if rule.Bidirectional { + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + } + } + + return peers, fwRules +} + +func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { + for _, groupID := range groupIDs { + if _, exists := peerGroupsMap[groupID]; exists { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, + excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, +) []*nbpeer.Peer { + ctx := context.Background() + uniquePeers := make(map[string]*nbpeer.Peer) + + for _, groupID := range groupIDs { + peerIDs := b.cache.groupToPeers[groupID] + for _, peerID := range peerIDs { + if peerID == excludePeerID { + continue + } + + if _, ok := validatedPeersMap[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + if len(postureChecksIDs) > 0 { + if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { + continue + } + } + + uniquePeers[peerID] = peer + } + } + + result := make([]*nbpeer.Peer, 0, len(uniquePeers)) + for _, peer := range uniquePeers { + result = append(result, peer) + } + + return result +} + +func (b *NetworkMapBuilder) generateResourcescached( + account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, + peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, +) { + isAll := false + if allGroup, err := account.GetGroupAll(); err == nil { + isAll = (len(allGroup.Peers) - 1) == len(groupPeers) + } + + for _, peer := range groupPeers { + if peer == nil { + continue + } + if _, ok := peersExists[peer.ID]; !ok { + *peers = append(*peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = allPeers + } + + var s strings.Builder + s.WriteString(rule.ID) + s.WriteString(fr.PeerIP) + s.WriteString(strconv.Itoa(direction)) + s.WriteString(fr.Protocol) + s.WriteString(fr.Action) + s.WriteString(strings.Join(rule.Ports, ",")) + + ruleID := s.String() + + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + *rules = append(*rules, &fr) + continue + } + + *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...) + } +} + +func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) { + ctx := context.Background() + peerID := peer.ID + + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, resource := range b.cache.globalResources { + + networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID] + resourcePolicies := b.cache.resourcePolicies[resource.ID] + if len(resourcePolicies) == 0 { + continue + } + + isRouterForThisResource := false + + if networkRoutingPeers != nil { + if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled { + isRoutingPeer = true + isRouterForThisResource = true + if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + + hasAccessAsClient := false + if !isRouterForThisResource { + for _, policy := range resourcePolicies { + if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + hasAccessAsClient = true + break + } + } + } + } + + if hasAccessAsClient && networkRoutingPeers != nil { + for routerPeerID, router := range networkRoutingPeers { + if router.Enabled { + if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + } + + if isRouterForThisResource { + for _, policy := range resourcePolicies { + var peersWithAccess []*nbpeer.Peer + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peersWithAccess = []*nbpeer.Peer{peer} + } else { + peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers) + } + for _, p := range peersWithAccess { + allSourcePeers[p.ID] = struct{}{} + } + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (b *NetworkMapBuilder) createNetworkResourceRoutes( + resource *resourceTypes.NetworkResource, routerPeerID string, + router *routerTypes.NetworkRouter, resourcePolicies []*Policy, +) *route.Route { + if len(resourcePolicies) > 0 { + peer := b.cache.globalPeers[routerPeerID] + if peer != nil { + return resource.ToRoute(peer, router) + } + } + return nil +} + +func (b *NetworkMapBuilder) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peer.ID) + delete(networkRoutesPeers, peer.ID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) + } + } + + return peersToConnect +} + +func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) { + ctx := context.Background() + peer := account.GetPeer(peerID) + if peer == nil { + return + } + resourcePolicies := b.cache.resourcePolicies + + view := &PeerRoutesView{ + OwnRouteIDs: make([]route.ID, 0), + NetworkResourceIDs: make([]route.ID, 0), + RouteFirewallRuleIDs: make([]string, 0), + } + + enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) + for _, rt := range enabledRoutes { + if rt.PeerID != "" && rt.PeerID != peerID { + if b.cache.globalPeers[rt.PeerID] == nil { + continue + } + } + + view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + aclView := b.cache.peerACLs[peerID] + if aclView != nil { + peerRoutesMembership := make(LookupMap) + for _, r := range append(enabledRoutes, disabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(LookupMap) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, aclPeerID := range aclView.ConnectedPeerIDs { + if aclPeerID == peerID { + continue + } + activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) + groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) + haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + + for _, inheritedRoute := range haFilteredRoutes { + view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) + b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute + } + } + } + + _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) + + for _, rt := range networkResourcesRoutes { + view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) + b.updateACGIndexForPeer(peerID, allRoutes) + + routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) + for _, rule := range routeFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + + if len(networkResourcesRoutes) > 0 { + networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) + for _, rule := range networkResourceFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + } + + b.cache.peerRoutes[peerID] = view +} + +func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == peerID { + delete(b.cache.noACGRoutes, routeID) + } + } + + for _, rt := range routes { + if !rt.Enabled { + continue + } + + if len(rt.AccessControlGroups) == 0 { + b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } else { + for _, acg := range rt.AccessControlGroups { + if b.cache.acgToRoutes[acg] == nil { + b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) + } + + b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } + } + } +} + +func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peer := b.cache.globalPeers[peerID] + if peer == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + // maybe here is some mess - here we store peer key (see comment below) + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + peerGroups := b.cache.peerToGroups[peerID] + for _, groupID := range peerGroups { + groupRoutes := b.cache.groupToRoutes[groupID] + for _, r := range groupRoutes { + newPeerRoute := r.Copy() + // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes + newPeerRoute.Peer = peerID + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) + takeRoute(newPeerRoute, peerID) + } + } + for _, r := range b.cache.peerToRoutes[peerID] { + takeRoute(r.Copy(), peerID) + } + return enabledRoutes, disabledRoutes +} + +func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) + for _, route := range enabledRoutes { + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := b.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) + + rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make(map[string]*Policy) + + for _, groupID := range accessControlGroups { + candidatePolicies := b.cache.groupToPolicies[groupID] + + for _, policy := range candidatePolicies { + if _, found := routePolicies[policy.ID]; found { + continue + } + policyRules := b.cache.policyToRules[policy.ID] + for _, rule := range policyRules { + if slices.Contains(rule.Destinations, groupID) { + routePolicies[policy.ID] = policy + break + } + } + } + } + + return maps.Values(routePolicies) +} + +func (b *NetworkMapBuilder) getRouteFirewallRules( + peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, + distributionPeers map[string]struct{}, account *Account, +) []*RouteFirewallRule { + ctx := context.Background() + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) + + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (b *NetworkMapBuilder) getRulePeers( + rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, + validatedPeersMap map[string]struct{}, account *Account, +) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + + for _, id := range rule.Sources { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := b.cache.globalPeers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { + peerGroups := b.cache.peerToGroups[peerID] + checkGroups := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + checkGroups[groupID] = struct{}{} + } + + dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) + dnsConfig := &nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) + } + + b.cache.peerDNS[peerID] = dnsConfig +} + +func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { + + enabled := true + for _, groupID := range account.DNSSettings.DisabledManagementGroups { + _, found := checkGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := checkGroups[gID] + if found { + peer := b.cache.globalPeers[peerID] + if !peerIsNameserver(peer, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) { + b.account.Store(account) +} + +func (b *NetworkMapBuilder) GetPeerNetworkMap( + ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + account := b.account.Load() + + peer := account.GetPeer(peerID) + if peer == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + aclView := b.cache.peerACLs[peerID] + routesView := b.cache.peerRoutes[peerID] + dnsConfig := b.cache.peerDNS[peerID] + + if aclView == nil || routesView == nil || dnsConfig == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", + account.Id, objectCount) + } + } + + return nm +} + +func (b *NetworkMapBuilder) assembleNetworkMap( + account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, + dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, +) *NetworkMap { + + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, peerID := range aclView.ConnectedPeerIDs { + if _, ok := validatedPeers[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) + if account.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, peer) + } else { + peersToConnect = append(peersToConnect, peer) + } + } + + var routes []*route.Route + allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) + + for _, routeID := range allRouteIDs { + if route := b.cache.globalRoutes[routeID]; route != nil { + routes = append(routes, route) + } + } + + var firewallRules []*FirewallRule + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil { + firewallRules = append(firewallRules, rule) + } + } + + var routesFirewallRules []*RouteFirewallRule + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + + finalDNSConfig := *dnsConfig + if finalDNSConfig.ServiceEnable && customZone.Domain != "" { + var zones []nbdns.CustomZone + records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: customZone.Domain, + Records: records, + }) + finalDNSConfig.CustomZones = zones + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: account.Network.Copy(), + Routes: routes, + DNSConfig: finalDNSConfig, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, + } +} + +func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { + var s strings.Builder + s.WriteString(fw) + s.WriteString(rule.PolicyID) + s.WriteRune(':') + s.WriteString(rule.PeerIP) + s.WriteRune(':') + s.WriteString(strconv.Itoa(rule.Direction)) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(rule.Port) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) + s.WriteRune('-') + s.WriteString(strconv.Itoa(int(rule.PortRange.End))) + return s.String() +} + +func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { + var s strings.Builder + s.WriteString(rfw) + s.WriteString(string(rule.RouteID)) + s.WriteRune(':') + s.WriteString(rule.Destination) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(strings.Join(rule.SourceRanges, ",")) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.Port))) + return s.String() +} + +func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { + for _, groupID := range groupIDs { + if slices.Contains(peerGroups, groupID) { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { + for _, r := range account.Routes { + if !r.Enabled { + continue + } + + if r.PeerID == peerID { + return true + } + + if peer := b.cache.globalPeers[peerID]; peer != nil { + if r.Peer == peer.Key && r.PeerID == "" { + return true + } + } + } + + routers := account.GetResourceRoutersMap() + for _, networkRouters := range routers { + if router, exists := networkRouters[peerID]; exists && router.Enabled { + return true + } + } + + return false +} + +type ViewDelta struct { + AddedPeerIDs []string + RemovedPeerIDs []string + AddedRuleIDs []string + RemovedRuleIDs []string +} + +func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error { + tt := time.Now() + account := b.account.Load() + peer := account.GetPeer(peerID) + if peer == nil { + return fmt.Errorf("peer %s not found in account", peerID) + } + + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) + + b.validatedPeers[peerID] = struct{}{} + + b.cache.globalPeers[peerID] = peer + + peerGroups := b.updateIndexesForNewPeer(account, peerID) + + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + + log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) + + b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) + + log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt)) + + return nil +} + +func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { + peerGroups := make([]string, 0) + + for groupID, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { + b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) + } + peerGroups = append(peerGroups, groupID) + } + } + + b.cache.peerToGroups[peerID] = peerGroups + + for _, r := range account.Routes { + if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { + continue + } + for _, groupID := range r.PeerGroups { + if !slices.Contains(b.cache.groupToRoutes[groupID], r) { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } + b.cache.globalRoutes[r.ID] = r + } + + return peerGroups +} + +func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { + updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) + + if b.isPeerRouter(account, newPeerID) { + affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) + for affectedPeerID := range affectedByRoutes { + if affectedPeerID == newPeerID { + continue + } + if _, exists := updates[affectedPeerID]; !exists { + updates[affectedPeerID] = &PeerUpdateDelta{ + PeerID: affectedPeerID, + RebuildRoutesView: true, + } + } else { + updates[affectedPeerID].RebuildRoutesView = true + } + } + } + + for affectedPeerID, delta := range updates { + b.applyDeltaToPeer(account, affectedPeerID, delta) + } +} + +func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { + affected := make(map[string]struct{}) + enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) + + for _, route := range enabledRoutes { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + + for _, peerGroupID := range route.PeerGroups { + if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + } + + for _, route := range account.Routes { + if !route.Enabled { + continue + } + + routerInPeerGroups := false + for _, peerGroupID := range route.PeerGroups { + if slices.Contains(routerGroups, peerGroupID) { + routerInPeerGroups = true + break + } + } + + if routerInPeerGroups { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + affected[peerID] = struct{}{} + } + } + } + } + } + + return affected +} + +func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { + updates := make(map[string]*PeerUpdateDelta) + ctx := context.Background() + + groupAllLn := 0 + if allGroup, err := account.GetGroupAll(); err == nil { + groupAllLn = len(allGroup.Peers) - 1 + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return updates + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + peerInSources := b.isPeerInGroups(rule.Sources, peerGroups) + peerInDestinations := b.isPeerInGroups(rule.Destinations, peerGroups) + + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + + if rule.Bidirectional { + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + } + } + } + + b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) + + b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) + + b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) + + return updates +} + +func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( + ctx context.Context, account *Account, newPeerID string, + updates map[string]*PeerUpdateDelta, +) { + resourceRouters := b.cache.resourceRouters + + for networkID, routers := range resourceRouters { + router, isRouter := routers[newPeerID] + if !isRouter || !router.Enabled { + continue + } + + for _, resource := range b.cache.globalResources { + if resource.NetworkID != networkID { + continue + } + + policies := b.cache.resourcePolicies[resource.ID] + if len(policies) == 0 { + continue + } + + peersWithAccess := make(map[string]struct{}) + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + groupPeers := b.cache.groupToPeers[sourceGroup] + for _, peerID := range groupPeers { + if peerID == newPeerID { + continue + } + + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + peersWithAccess[peerID] = struct{}{} + } + } + } + } + + for peerID := range peersWithAccess { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + } + updates[peerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } + } +} + +func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( + newPeerID string, newPeer *nbpeer.Peer, + peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + processedPeerRoutes := make(map[string]map[route.ID]struct{}) + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == newPeerID { + continue + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + + for _, acg := range peerGroups { + routeInfos := b.cache.acgToRoutes[acg] + if routeInfos == nil { + continue + } + + for routeID, info := range routeInfos { + if info.PeerID == newPeerID { + continue + } + + if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { + if _, processed := processedRoutes[routeID]; processed { + continue + } + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + } +} + +func (b *NetworkMapBuilder) addRouteFirewallUpdate( + updates map[string]*PeerUpdateDelta, peerID string, + routeID string, sourceIP string, +) { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), + } + updates[peerID] = delta + } + + for _, existing := range delta.UpdateRouteFirewallRules { + if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { + return + } + } + + delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ + RuleID: routeID, + AddSourceIP: sourceIP, + }) +} + +func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( + ctx context.Context, account *Account, newPeerID string, + newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + for _, resource := range b.cache.globalResources { + resourcePolicies := b.cache.resourcePolicies + resourceRouters := b.cache.resourceRouters + + policies := resourcePolicies[resource.ID] + peerHasAccess := false + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + if slices.Contains(peerGroups, sourceGroup) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { + peerHasAccess = true + break + } + } + } + + if peerHasAccess { + break + } + } + + if !peerHasAccess { + continue + } + + networkRouters := resourceRouters[resource.NetworkID] + for routerPeerID, router := range networkRouters { + if !router.Enabled || routerPeerID == newPeerID { + continue + } + + delta := updates[routerPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: routerPeerID, + } + updates[routerPeerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } +} + +type PeerUpdateDelta struct { + PeerID string + AddConnectedPeer string + AddFirewallRules []*FirewallRuleDelta + AddRoutes []route.ID + UpdateRouteFirewallRules []*RouteFirewallRuleUpdate + UpdateDNS bool + RebuildRoutesView bool +} +type FirewallRuleDelta struct { + Rule *FirewallRule + RuleID string + Direction int +} + +type RouteFirewallRuleUpdate struct { + RuleID string + AddSourceIP string +} + +func (b *NetworkMapBuilder) addUpdateForPeersInGroups( + updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, + rule *PolicyRule, direction int, allGroupLn int, +) { + for _, groupID := range groupIDs { + peers := b.cache.groupToPeers[groupID] + cnt := 0 + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + cnt++ + } + all := false + if allGroupLn > 0 && cnt == allGroupLn { + all = true + } + newPeer := b.cache.globalPeers[newPeerID] + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + AddConnectedPeer: newPeerID, + AddFirewallRules: make([]*FirewallRuleDelta, 0), + } + updates[peerID] = delta + } + + if all { + fr.PeerIP = allPeers + } + + if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { + expandedRules := expandPortsAndRanges(*fr, rule, b.cache.globalPeers[peerID]) + for _, expandedRule := range expandedRules { + ruleID := b.generateFirewallRuleID(expandedRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: expandedRule, + RuleID: ruleID, + Direction: direction, + }) + } + } else { + ruleID := b.generateFirewallRuleID(fr) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: fr, + RuleID: ruleID, + Direction: direction, + }) + } + } + } +} + +func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { + if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) { + aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) + } + + for _, ruleDelta := range delta.AddFirewallRules { + b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule + + if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { + aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) + } + } + } + } + + if delta.RebuildRoutesView { + b.buildPeerRoutesView(account, peerID) + } else if len(delta.UpdateRouteFirewallRules) > 0 { + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) + } + } + + if delta.UpdateDNS { + b.buildPeerDNSView(account, peerID) + } +} + +func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { + for _, update := range updates { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + rule := b.cache.globalRouteRules[ruleID] + if rule == nil { + continue + } + + if string(rule.RouteID) == update.RuleID { + sourceIP := update.AddSourceIP + + if strings.Contains(sourceIP, ":") { + sourceIP += "/128" // IPv6 + } else { + sourceIP += "/32" // IPv4 + } + + if !slices.Contains(rule.SourceRanges, sourceIP) { + rule.SourceRanges = append(rule.SourceRanges, sourceIP) + } + break + } + } + } +} + +func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + account := b.account.Load() + + deletedPeer := b.cache.globalPeers[peerID] + if deletedPeer == nil { + return fmt.Errorf("peer %s not found in cache", peerID) + } + + deletedPeerKey := deletedPeer.Key + peerGroups := b.cache.peerToGroups[peerID] + peerIP := deletedPeer.IP.String() + + log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) + + delete(b.validatedPeers, peerID) + + routesToDelete := []route.ID{} + + for routeID, r := range account.Routes { + if r.Peer != deletedPeerKey && r.PeerID != peerID { + continue + } + if len(r.PeerGroups) == 0 { + routesToDelete = append(routesToDelete, routeID) + continue + } + newPeerAssigned := false + for _, groupID := range r.PeerGroups { + candidatePeerIDs := b.cache.groupToPeers[groupID] + for _, candidatePeerID := range candidatePeerIDs { + if candidatePeerID == peerID { + continue + } + if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { + r.Peer = candidatePeer.Key + r.PeerID = candidatePeerID + newPeerAssigned = true + break + } + } + if newPeerAssigned { + break + } + } + + if !newPeerAssigned { + routesToDelete = append(routesToDelete, routeID) + } + } + + for _, routeID := range routesToDelete { + delete(account.Routes, routeID) + } + + delete(b.cache.peerACLs, peerID) + delete(b.cache.peerRoutes, peerID) + delete(b.cache.peerDNS, peerID) + + delete(b.cache.globalPeers, peerID) + + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for _, groupID := range peerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { + return id == peerID + }) + } + } + delete(b.cache.peerToGroups, peerID) + + affectedPeers := make(map[string]struct{}) + + for _, r := range account.Routes { + for _, groupID := range r.Groups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + + for _, groupID := range r.PeerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + } + + for affectedPeerID := range affectedPeers { + if affectedPeerID == peerID { + continue + } + b.buildPeerRoutesView(account, affectedPeerID) + } + + peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP) + for affectedPeerID, updates := range peerDeletionUpdates { + b.applyDeletionUpdates(affectedPeerID, updates) + } + + b.cleanupUnusedRules() + + log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) + + return nil +} + +func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( + deletedPeerID string, + peerIP string, +) map[string]*PeerDeletionUpdate { + + affected := make(map[string]*PeerDeletionUpdate) + + for peerID, aclView := range b.cache.peerACLs { + if peerID == deletedPeerID { + continue + } + + if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { + continue + } + if affected[peerID] == nil { + affected[peerID] = &PeerDeletionUpdate{ + RemovePeerID: deletedPeerID, + PeerIP: peerIP, + } + } + + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP { + affected[peerID].RemoveFirewallRuleIDs = append( + affected[peerID].RemoveFirewallRuleIDs, + ruleID, + ) + } + } + } + + return affected +} + +type PeerDeletionUpdate struct { + RemovePeerID string + RemoveFirewallRuleIDs []string + RemoveRouteIDs []route.ID + RemoveFromSourceRanges bool + PeerIP string +} + +func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool { + return id == updates.RemovePeerID + }) + + if len(updates.RemoveFirewallRuleIDs) > 0 { + aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID) + }) + } + } + + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + if len(updates.RemoveRouteIDs) > 0 { + routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { + return slices.Contains(updates.RemoveRouteIDs, routeID) + }) + } + + if updates.RemoveFromSourceRanges { + b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) + } + } +} + +func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { + sourceIPv4 := peerIP + "/32" + sourceIPv6 := peerIP + "/128" + + rulesToRemove := []string{} + + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { + return source == sourceIPv4 || source == sourceIPv6 || source == peerIP + }) + + if len(rule.SourceRanges) == 0 { + rulesToRemove = append(rulesToRemove, ruleID) + } + } + } + + if len(rulesToRemove) > 0 { + routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(rulesToRemove, ruleID) + }) + } +} + +func (b *NetworkMapBuilder) cleanupUnusedRules() { + usedFirewallRules := make(map[string]struct{}) + usedRouteRules := make(map[string]struct{}) + usedRoutes := make(map[route.ID]struct{}) + + for _, aclView := range b.cache.peerACLs { + for _, ruleID := range aclView.FirewallRuleIDs { + usedFirewallRules[ruleID] = struct{}{} + } + } + + for _, routesView := range b.cache.peerRoutes { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + usedRouteRules[ruleID] = struct{}{} + } + + for _, routeID := range routesView.OwnRouteIDs { + usedRoutes[routeID] = struct{}{} + } + for _, routeID := range routesView.NetworkResourceIDs { + usedRoutes[routeID] = struct{}{} + } + } + + for ruleID := range b.cache.globalRules { + if _, used := usedFirewallRules[ruleID]; !used { + delete(b.cache.globalRules, ruleID) + } + } + + for ruleID := range b.cache.globalRouteRules { + if _, used := usedRouteRules[ruleID]; !used { + delete(b.cache.globalRouteRules, ruleID) + } + } + + for routeID := range b.cache.globalRoutes { + if _, used := usedRoutes[routeID]; !used { + delete(b.cache.globalRoutes, routeID) + } + } +} + +func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + peerStored, ok := b.cache.globalPeers[peer.ID] + if !ok { + return + } + *peerStored = *peer +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index da12f1b70..adf64592a 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -7,16 +7,14 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse - NetworkMap *types.NetworkMap + Update *proto.SyncResponse } type PeersUpdateManager struct { diff --git a/management/server/user.go b/management/server/user.go index 25c87df9c..66bea314f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -991,6 +991,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) + + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } } if len(peerIDs) != 0 { diff --git a/route/route.go b/route/route.go index 08a2d37dc..c724e7c7d 100644 --- a/route/route.go +++ b/route/route.go @@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any { func (r *Route) Copy() *Route { route := &Route{ ID: r.ID, + AccountID: r.AccountID, Description: r.Description, NetID: r.NetID, Network: r.Network, From 48475ddc058f71fccba80d622178d3f5abab79f0 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:50:18 +0100 Subject: [PATCH 045/118] [management] add pat rate limiting (#4741) --- go.mod | 2 +- management/server/http/handler.go | 42 ++- .../server/http/middleware/auth_middleware.go | 20 +- .../http/middleware/auth_middleware_test.go | 285 +++++++++++++++++- .../server/http/middleware/rate_limiter.go | 146 +++++++++ shared/management/http/util/util.go | 2 + shared/management/status/error.go | 3 + 7 files changed, 496 insertions(+), 4 deletions(-) create mode 100644 management/server/http/middleware/rate_limiter.go diff --git a/go.mod b/go.mod index 68a12908d..7b9bae321 100644 --- a/go.mod +++ b/go.mod @@ -108,6 +108,7 @@ require ( golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.33.0 + golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -245,7 +246,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect - golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3d4de31d0..4d2c224b4 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -4,9 +4,13 @@ import ( "context" "fmt" "net/http" + "os" + "strconv" + "time" "github.com/gorilla/mux" "github.com/rs/cors" + log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" @@ -38,7 +42,12 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const apiPrefix = "/api" +const ( + apiPrefix = "/api" + rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" + rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" + rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" +) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. func NewAPIHandler( @@ -58,11 +67,42 @@ func NewAPIHandler( settingsManager settings.Manager, ) (http.Handler, error) { + var rateLimitingConfig *middleware.RateLimiterConfig + if os.Getenv(rateLimitingEnabledKey) == "true" { + rpm := 6 + if v := os.Getenv(rateLimitingRPMKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) + } else { + rpm = value + } + } + + burst := 500 + if v := os.Getenv(rateLimitingBurstKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) + } else { + burst = value + } + } + + rateLimitingConfig = &middleware.RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + } + } + authMiddleware := middleware.NewAuthMiddleware( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, + rateLimitingConfig, ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6091a4c31..bce917a25 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -29,6 +29,7 @@ type AuthMiddleware struct { ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc + rateLimiter *APIRateLimiter } // NewAuthMiddleware instance constructor @@ -37,12 +38,19 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, + rateLimiterConfig *RateLimiterConfig, ) *AuthMiddleware { + var rateLimiter *APIRateLimiter + if rateLimiterConfig != nil { + rateLimiter = NewAPIRateLimiter(rateLimiterConfig) + } + return &AuthMiddleware{ authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, + rateLimiter: rateLimiter, } } @@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { request, err := m.checkPATFromRequest(r, auth) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) - util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) + // Check if it's a status error, otherwise default to Unauthorized + if _, ok := status.FromError(err); !ok { + err = status.Errorf(status.Unauthorized, "token invalid") + } + util.WriteError(r.Context(), err, w) return } h.ServeHTTP(w, request) @@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, fmt.Errorf("error extracting token: %w", err) } + if m.rateLimiter != nil { + if !m.rateLimiter.Allow(token) { + return r, status.Errorf(status.TooManyRequests, "too many requests") + } + } + ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d815f5422..d1bd9959f 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -27,7 +27,9 @@ const ( domainCategory = "domainCategory" userID = "userID" tokenID = "tokenID" + tokenID2 = "tokenID2" PAT = "nbp_PAT" + PAT2 = "nbp_PAT2" JWT = "JWT" wrongToken = "wrongToken" ) @@ -49,6 +51,15 @@ var testAccount = &types.Account{ CreatedAt: time.Now().UTC(), LastUsed: util.ToPtr(time.Now().UTC()), }, + tokenID2: { + ID: tokenID2, + Name: "My second token", + HashedToken: "someHash2", + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), + CreatedBy: userID, + CreatedAt: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), + }, }, }, }, @@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use if token == PAT { return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } + if token == PAT2 { + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil + } return nil, nil, "", "", fmt.Errorf("PAT invalid") } @@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA } func mockMarkPATUsed(_ context.Context, token string) error { - if token == tokenID { + if token == tokenID || token == tokenID2 { return nil } return fmt.Errorf("Should never get reached") @@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) { } } +func TestAuthMiddleware_RateLimiting(t *testing.T) { + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } + + t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) { + // Configure rate limiter: 10 requests per minute with burst of 5 + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 10, + Burst: 5, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make burst requests - all should succeed + successCount := 0 + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 5, successCount, "All burst requests should succeed") + + // The 6th request should fail (exceeded burst) + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited") + }) + + t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) { + // Configure very low rate limit: 1 request per minute + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request should fail (rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + }) + + t.Run("Bearer Token Not Rate Limited", func(t *testing.T) { + // Configure strict rate limit + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make multiple requests with Bearer token - all should succeed + successCount := 0 + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)") + }) + + t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) { + // Configure rate limiter + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Use first PAT token + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed") + + // Second request with same token should fail + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited") + + // Use second PAT token - should succeed because it has independent rate limit + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)") + + // Second request with PAT2 should also be rate limited + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited") + + // JWT should still work (not rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)") + }) + + t.Run("Rate Limiter Cleanup", func(t *testing.T) { + // Configure rate limiter with short cleanup interval and TTL for testing + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: 100 * time.Millisecond, + LimiterTTL: 200 * time.Millisecond, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request - should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request immediately - should fail (burst exhausted) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + + // Wait for limiter to be cleaned up (TTL + cleanup interval + buffer) + time.Sleep(400 * time.Millisecond) + + // After cleanup, the limiter should be removed and recreated with full burst capacity + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)") + + // Verify it's a fresh limiter by checking burst is reset + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again") + }) +} + func TestAuthMiddleware_Handler_Child(t *testing.T) { tt := []struct { name string @@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go new file mode 100644 index 000000000..a6266d4f3 --- /dev/null +++ b/management/server/http/middleware/rate_limiter.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "context" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiterConfig holds configuration for the API rate limiter +type RateLimiterConfig struct { + // RequestsPerMinute defines the rate at which tokens are replenished + RequestsPerMinute float64 + // Burst defines the maximum number of requests that can be made in a burst + Burst int + // CleanupInterval defines how often to clean up old limiters (how often garbage collection runs) + CleanupInterval time.Duration + // LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal) + LimiterTTL time.Duration +} + +// DefaultRateLimiterConfig returns a default configuration +func DefaultRateLimiterConfig() *RateLimiterConfig { + return &RateLimiterConfig{ + RequestsPerMinute: 100, + Burst: 120, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } +} + +// limiterEntry holds a rate limiter and its last access time +type limiterEntry struct { + limiter *rate.Limiter + lastAccess time.Time +} + +// APIRateLimiter manages rate limiting for API tokens +type APIRateLimiter struct { + config *RateLimiterConfig + limiters map[string]*limiterEntry + mu sync.RWMutex + stopChan chan struct{} +} + +// NewAPIRateLimiter creates a new API rate limiter with the given configuration +func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { + if config == nil { + config = DefaultRateLimiterConfig() + } + + rl := &APIRateLimiter{ + config: config, + limiters: make(map[string]*limiterEntry), + stopChan: make(chan struct{}), + } + + go rl.cleanupLoop() + + return rl +} + +// Allow checks if a request for the given key (token) is allowed +func (rl *APIRateLimiter) Allow(key string) bool { + limiter := rl.getLimiter(key) + return limiter.Allow() +} + +// Wait blocks until the rate limiter allows another request for the given key +// Returns an error if the context is canceled +func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + limiter := rl.getLimiter(key) + return limiter.Wait(ctx) +} + +// getLimiter retrieves or creates a rate limiter for the given key +func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.RLock() + entry, exists := rl.limiters[key] + rl.mu.RUnlock() + + if exists { + rl.mu.Lock() + entry.lastAccess = time.Now() + rl.mu.Unlock() + return entry.limiter + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + if entry, exists := rl.limiters[key]; exists { + entry.lastAccess = time.Now() + return entry.limiter + } + + requestsPerSecond := rl.config.RequestsPerMinute / 60.0 + limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst) + rl.limiters[key] = &limiterEntry{ + limiter: limiter, + lastAccess: time.Now(), + } + + return limiter +} + +// cleanupLoop periodically removes old limiters that haven't been used recently +func (rl *APIRateLimiter) cleanupLoop() { + ticker := time.NewTicker(rl.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + rl.cleanup() + case <-rl.stopChan: + return + } + } +} + +// cleanup removes limiters that haven't been used within the TTL period +func (rl *APIRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for key, entry := range rl.limiters { + if now.Sub(entry.lastAccess) > rl.config.LimiterTTL { + delete(rl.limiters, key) + } + } +} + +// Stop stops the cleanup goroutine +func (rl *APIRateLimiter) Stop() { + close(rl.stopChan) +} + +// Reset removes the rate limiter for a specific key +func (rl *APIRateLimiter) Reset(key string) { + rl.mu.Lock() + defer rl.mu.Unlock() + delete(rl.limiters, key) +} diff --git a/shared/management/http/util/util.go b/shared/management/http/util/util.go index 3ae321023..0a29469da 100644 --- a/shared/management/http/util/util.go +++ b/shared/management/http/util/util.go @@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) { httpStatus = http.StatusUnauthorized case status.BadRequest: httpStatus = http.StatusBadRequest + case status.TooManyRequests: + httpStatus = http.StatusTooManyRequests default: } msg = strings.ToLower(err.Error()) diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 1e914babb..09676847e 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -37,6 +37,9 @@ const ( // Unauthenticated indicates that user is not authenticated due to absence of valid credentials Unauthenticated Type = 10 + + // TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting) + TooManyRequests Type = 11 ) // Type is a type of the Error From 98ddac07bfb4159a6eeeb2178604116290147a4d Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:50:58 +0100 Subject: [PATCH 046/118] [management] remove toAll firewall rule (#4725) --- management/server/policy_test.go | 152 ++++++++++++++++++++++++++++- management/server/types/account.go | 11 --- 2 files changed, 148 insertions(+), 15 deletions(-) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 4a08f4c33..97ebbcf5a 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", @@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) { PolicyID: "RuleDefault", }, { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", @@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // We expect a single permissive firewall rule which all outgoing connections peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) - assert.Len(t, firewallRules, 1) + assert.Len(t, firewallRules, 7) expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.80.39", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.21.56", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", diff --git a/management/server/types/account.go b/management/server/types/account.go index dd6052498..3818a84ce 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1062,14 +1062,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer rules := make([]*FirewallRule, 0) peers := make([]*nbpeer.Peer, 0) - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &Group{} - } - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) for _, peer := range groupPeers { if peer == nil { continue @@ -1088,10 +1081,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer Protocol: string(rule.Protocol), } - if isAll { - fr.PeerIP = "0.0.0.0" - } - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") if _, ok := rulesExists[ruleID]; ok { From dbfc8a52c932a4a246c07938eeabfaef78f66c3b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 7 Nov 2025 16:03:14 +0100 Subject: [PATCH 047/118] [management] remove GLOBAL when disabling foreign keys on mysql (#4615) --- management/server/store/sql_store.go | 60 +++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d83d160c3..3146254a3 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -85,12 +85,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met conns = runtime.NumCPU() } - switch storeEngine { - case types.MysqlStoreEngine: - if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil { - return nil, err - } - case types.SqliteStoreEngine: + if storeEngine == types.SqliteStoreEngine { if err == nil { log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } @@ -175,7 +170,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro group.StoreGroupPeers() } - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -270,7 +265,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -324,7 +319,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -613,7 +608,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error { - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) if result.Error != nil { return result.Error @@ -2932,6 +2927,16 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor if tx.Error != nil { return tx.Error } + + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + repo := s.withTx(tx) err := operation(repo) if err != nil { @@ -2939,6 +2944,14 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor return err } + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to re-enable FK checks: %w", err) + } + } + err = tx.Commit().Error log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) @@ -2956,6 +2969,31 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { } } +// transaction wraps a GORM transaction with MySQL-specific FK checks handling +// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora +func (s *SqlStore) transaction(fn func(*gorm.DB) error) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + + err := fn(tx) + + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine && err == nil { + if fkErr := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; fkErr != nil { + return fmt.Errorf("failed to re-enable FK checks: %w", fkErr) + } + } + + return err + }) +} + func (s *SqlStore) GetDB() *gorm.DB { return s.db } @@ -3212,7 +3250,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } From 7df49e249dc035ece05d2514144ddac67674201f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 7 Nov 2025 20:14:52 +0100 Subject: [PATCH 048/118] [management ] remove timing logs (#4761) --- management/server/account.go | 5 ----- management/server/grpcserver.go | 10 --------- management/server/peer.go | 33 ---------------------------- management/server/store/sql_store.go | 13 ----------- management/server/types/account.go | 4 ---- 5 files changed, 65 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 0aecbd586..f5a5c7b7a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1672,11 +1672,6 @@ func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) boo } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start)) - }() - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 0a5236cb3..d3d94443a 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -775,11 +775,6 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("toSyncResponse: took %s", time.Since(start)) - }() - response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ @@ -844,11 +839,6 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("sendInitialSync: took %s", time.Since(start)) - }() - var err error var turnToken *Token diff --git a/management/server/peer.go b/management/server/peer.go index 80ab7fc69..4c605b5eb 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -106,11 +106,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start)) - }() - var peer *nbpeer.Peer var settings *types.Settings var expired bool @@ -744,11 +739,6 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start)) - }() - var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool @@ -822,10 +812,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("handlePeerNotFound: took %s", time.Since(start)) - }() if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -847,11 +833,6 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("LoginPeer: took %s", time.Since(start)) - }() - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -965,11 +946,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer // getPeerPostureChecks returns the posture checks for the peer. func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getPostureChecks: took %s", time.Since(start)) - }() - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err @@ -1058,11 +1034,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co } func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start)) - }() - if isRequiresApproval { network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -1611,10 +1582,6 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getPeerGroupIDs: took %s", time.Since(start)) - }() return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 3146254a3..75f2c3ae7 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -311,10 +311,6 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SavePeer: took %s", time.Since(start)) - }() // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID @@ -2156,10 +2152,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("GetAccountNetwork: took %s", time.Since(start)) - }() ctx, cancel := getDebuggingCtx(ctx) defer cancel() @@ -2201,11 +2193,6 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getAccountSettings: took %s", time.Since(start)) - }() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) diff --git a/management/server/types/account.go b/management/server/types/account.go index 3818a84ce..8797e1fa3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -265,10 +265,6 @@ func (a *Account) GetPeerNetworkMap( metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("GetPeerNetworkMap: took %s", time.Since(start)) - }() - peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ From 07cf9d5895404b1a4e04fd8de17918d5c0edf8ce Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 8 Nov 2025 10:54:37 +0100 Subject: [PATCH 049/118] [client] Create networkd.conf.d if it doesn't exist (#4764) --- client/cmd/service_installer.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 2a87e538d..f6828d96a 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -259,6 +259,7 @@ func isServiceRunning() (bool, error) { } const ( + networkdConf = "/etc/systemd/networkd.conf" networkdConfDir = "/etc/systemd/networkd.conf.d" networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing @@ -273,12 +274,16 @@ ManageForeignRoutingPolicyRules=no // configureSystemdNetworkd creates a drop-in configuration file to prevent // systemd-networkd from removing NetBird's routes and policy rules. func configureSystemdNetworkd() error { - parentDir := filepath.Dir(networkdConfDir) - if _, err := os.Stat(parentDir); os.IsNotExist(err) { - log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration") + if _, err := os.Stat(networkdConf); os.IsNotExist(err) { + log.Debug("systemd-networkd not in use, skipping configuration") return nil } + // nolint:gosec // standard networkd permissions + if err := os.MkdirAll(networkdConfDir, 0755); err != nil { + return fmt.Errorf("create networkd.conf.d directory: %w", err) + } + // nolint:gosec // standard networkd permissions if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { return fmt.Errorf("write networkd configuration: %w", err) From 56f169eedeb4f6a3b8d759bf7f735789322c8bc1 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Mon, 10 Nov 2025 23:43:08 +0100 Subject: [PATCH 050/118] [management] fix pg db deadlock after app panic (#4772) --- management/server/store/sql_store.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 75f2c3ae7..94b7fc1cc 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2914,6 +2914,23 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor if tx.Error != nil { return tx.Error } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + }() + + if s.storeEngine == types.PostgresStoreEngine { + if err := tx.Exec("SET LOCAL statement_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set statement timeout: %w", err) + } + if err := tx.Exec("SET LOCAL lock_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set lock timeout: %w", err) + } + } // For MySQL, disable FK checks within this transaction to avoid deadlocks // This is session-scoped and doesn't require SUPER privileges From c28275611b82de1fa04bc022458ae97dd325d7fc Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 11 Nov 2025 13:59:32 +0100 Subject: [PATCH 051/118] Fix agent reference (#4776) --- client/internal/peer/worker_ice.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 3675f0157..5d8ebfe45 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { if isController(w.config) { - return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } From cc97cffff1491d84d9cad39bae53a64998b6857b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 13 Nov 2025 12:09:46 +0100 Subject: [PATCH 052/118] [management] move network map logic into new design (#4774) --- client/cmd/testutil_test.go | 15 +- client/internal/engine_test.go | 13 +- client/server/server_test.go | 13 +- go.mod | 2 +- .../controller/cache/dns_config_cache.go | 31 + .../network_map/controller/controller.go | 784 ++++++++++++++++++ .../network_map/controller/controller_test.go | 244 ++++++ .../network_map/controller/metrics.go | 15 + .../network_map/controller/repository.go | 39 + .../controllers/network_map/interface.go | 39 + .../controllers/network_map/interface_mock.go | 225 +++++ .../controllers/network_map/network_map.go | 1 + .../controllers/network_map/update_channel.go | 13 + .../update_channel}/updatechannel.go | 26 +- .../update_channel}/updatechannel_test.go | 7 +- .../controllers/network_map/update_message.go | 9 + management/internals/server/boot.go | 6 +- management/internals/server/controllers.go | 28 +- management/internals/server/modules.go | 3 +- .../internals/shared/grpc/conversion.go | 352 ++++++++ .../internals/shared/grpc/conversion_test.go | 150 ++++ .../shared/grpc}/loginfilter.go | 2 +- .../shared/grpc}/loginfilter_test.go | 2 +- .../shared/grpc/server.go} | 249 +----- .../internals/shared/grpc/server_test.go | 106 +++ .../shared/grpc}/token_mgr.go | 11 +- .../shared/grpc}/token_mgr_test.go | 12 +- management/server/account.go | 127 +-- management/server/account/manager.go | 11 +- management/server/account/request_buffer.go | 11 + management/server/account_test.go | 151 ++-- management/server/dns.go | 130 --- management/server/dns_test.go | 257 +----- management/server/event_test.go | 2 +- management/server/group.go | 26 +- management/server/group_test.go | 20 +- management/server/holder.go | 39 - management/server/http/handler.go | 4 +- .../http/handlers/peers/peers_handler.go | 23 +- .../http/handlers/peers/peers_handler_test.go | 29 +- .../testing/testing_tools/channel/channel.go | 20 +- management/server/management_proto_test.go | 142 +--- management/server/management_test.go | 21 +- management/server/mock_server/account_mock.go | 12 +- management/server/nameserver.go | 9 - management/server/nameserver_test.go | 16 +- management/server/networkmap.go | 80 -- management/server/networks/manager.go | 3 - .../server/networks/resources/manager.go | 9 - management/server/networks/routers/manager.go | 9 - management/server/peer.go | 446 +--------- management/server/peer_test.go | 235 ++---- management/server/policy.go | 35 - management/server/policy_test.go | 6 +- management/server/posture_checks.go | 72 -- management/server/posture_checks_test.go | 14 +- management/server/route.go | 101 --- management/server/route_test.go | 33 +- management/server/setupkey_test.go | 14 +- management/server/user.go | 23 +- management/server/user_test.go | 16 +- shared/management/client/client_test.go | 14 +- 62 files changed, 2568 insertions(+), 1989 deletions(-) create mode 100644 management/internals/controllers/network_map/controller/cache/dns_config_cache.go create mode 100644 management/internals/controllers/network_map/controller/controller.go create mode 100644 management/internals/controllers/network_map/controller/controller_test.go create mode 100644 management/internals/controllers/network_map/controller/metrics.go create mode 100644 management/internals/controllers/network_map/controller/repository.go create mode 100644 management/internals/controllers/network_map/interface.go create mode 100644 management/internals/controllers/network_map/interface_mock.go create mode 100644 management/internals/controllers/network_map/network_map.go create mode 100644 management/internals/controllers/network_map/update_channel.go rename management/{server => internals/controllers/network_map/update_channel}/updatechannel.go (87%) rename management/{server => internals/controllers/network_map/update_channel}/updatechannel_test.go (89%) create mode 100644 management/internals/controllers/network_map/update_message.go create mode 100644 management/internals/shared/grpc/conversion.go create mode 100644 management/internals/shared/grpc/conversion_test.go rename management/{server => internals/shared/grpc}/loginfilter.go (99%) rename management/{server => internals/shared/grpc}/loginfilter_test.go (99%) rename management/{server/grpcserver.go => internals/shared/grpc/server.go} (76%) create mode 100644 management/internals/shared/grpc/server_test.go rename management/{server => internals/shared/grpc}/token_mgr.go (93%) rename management/{server => internals/shared/grpc}/token_mgr_test.go (94%) create mode 100644 management/server/account/request_buffer.go delete mode 100644 management/server/holder.go delete mode 100644 management/server/networkmap.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index bd3209605..78bb0476b 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,9 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" @@ -84,7 +87,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, nil @@ -110,13 +112,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 2f1098100..15ac0a947 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -26,6 +26,9 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" @@ -1556,7 +1559,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -1584,13 +1586,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index e0a4805f6..ae5f759ee 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -14,6 +14,9 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" @@ -290,7 +293,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -311,13 +313,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 7b9bae321..1e3177d7e 100644 --- a/go.mod +++ b/go.mod @@ -99,6 +99,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.35.0 go.opentelemetry.io/otel/sdk/metric v1.35.0 + go.uber.org/mock v0.5.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 @@ -242,7 +243,6 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect diff --git a/management/internals/controllers/network_map/controller/cache/dns_config_cache.go b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go new file mode 100644 index 000000000..8cc634ef4 --- /dev/null +++ b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go @@ -0,0 +1,31 @@ +package cache + +import ( + "sync" + + "github.com/netbirdio/netbird/shared/management/proto" +) + +// DNSConfigCache is a thread-safe cache for DNS configuration components +type DNSConfigCache struct { + NameServerGroups sync.Map +} + +// GetNameServerGroup retrieves a cached name server group +func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { + if c == nil { + return nil, false + } + if value, ok := c.NameServerGroups.Load(key); ok { + return value.(*proto.NameServerGroup), true + } + return nil, false +} + +// SetNameServerGroup stores a name server group in the cache +func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { + if c == nil { + return + } + c.NameServerGroups.Store(key, value) +} diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go new file mode 100644 index 000000000..ad25494c7 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller.go @@ -0,0 +1,784 @@ +package controller + +import ( + "context" + "errors" + "fmt" + "os" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "golang.org/x/mod/semver" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util" +) + +type Controller struct { + repo Repository + metrics *metrics + // This should not be here, but we need to maintain it for the time being + accountManagerMetrics *telemetry.AccountManagerMetrics + peersUpdateManager network_map.PeersUpdateManager + settingsManager settings.Manager + + accountUpdateLocks sync.Map + sendAccountUpdateLocks sync.Map + updateAccountPeersBufferInterval atomic.Int64 + // dnsDomain is used for peer resolution. This is appended to the peer's name + dnsDomain string + + requestBuffer account.RequestBuffer + + proxyController port_forwarding.Controller + + integratedPeerValidator integrated_validator.IntegratedValidator + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} +} + +type bufferUpdate struct { + mu sync.Mutex + next *time.Timer + update atomic.Bool +} + +var _ network_map.Controller = (*Controller)(nil) + +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller { + nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) + if err != nil { + log.Fatal(fmt.Errorf("error creating metrics: %w", err)) + } + + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + + return &Controller{ + repo: newRepository(store), + metrics: nMetrics, + accountManagerMetrics: metrics.AccountManagerMetrics(), + peersUpdateManager: peersUpdateManager, + requestBuffer: requestBuffer, + integratedPeerValidator: integratedPeerValidator, + settingsManager: settingsManager, + dnsDomain: dnsDomain, + + proxyController: proxyController, + + holder: types.NewHolder(), + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, + } +} + +func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) + } + } + + globalStart := time.Now() + + hasPeersConnected := false + for _, peer := range account.Peers { + if c.peersUpdateManager.HasChannel(peer.ID) { + hasPeersConnected = true + break + } + + } + + if !hasPeersConnected { + return nil + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validate peers: %v", err) + } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, 10) + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + if c.experimentalNetworkMap(accountID) { + c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return fmt.Errorf("failed to get proxy network maps: %v", err) + } + + extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get flow enabled status: %v", err) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + for _, peer := range account.Peers { + if !c.peersUpdateManager.HasChannel(peer.ID) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) + continue + } + + wg.Add(1) + semaphore <- struct{}{} + go func(p *nbpeer.Peer) { + defer wg.Done() + defer func() { <-semaphore }() + + start := time.Now() + + postureChecks, err := c.getPeerPostureChecks(account, p.ID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err) + return + } + + c.metrics.CountCalcPostureChecksDuration(time.Since(start)) + start = time.Now() + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + } + + c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + peerGroups := account.GetPeerGroups(p.ID) + start = time.Now() + update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) + c.metrics.CountToSyncResponseDuration(time.Since(start)) + + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + }(peer) + } + + wg.Wait() + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart)) + } + + return nil +} + +func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.sendUpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.sendUpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +// UpdatePeers updates all peers that belong to an account. +// Should be called when changes have to be synced to peers. +func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { + if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return fmt.Errorf("recalculate network map cache: %v", err) + } + + return c.sendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { + if !c.peersUpdateManager.HasChannel(peerId) { + return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) + } + + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err) + } + + peer := account.GetPeer(peerId) + if peer == nil { + return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId) + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validated peers: %v", err) + } + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + postureChecks, err := c.getPeerPostureChecks(account, peerId) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) + return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return err + } + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID) + if err != nil { + return fmt.Errorf("failed to get extra settings: %v", err) + } + + peerGroups := account.GetPeerGroups(peerId) + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + + return nil +} + +func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.UpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.UpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error { + network, err := c.repo.GetAccountNetwork(ctx, accountId) + if err != nil { + return err + } + + peers, err := c.repo.GetAccountPeers(ctx, accountId) + if err != nil { + return err + } + + dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) + c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, + }, + }, + }) + c.peersUpdateManager.CloseChannel(ctx, peerId) + return nil +} + +func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + if isRequiresApproval { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + + emptyMap := &types.NetworkMap{ + Network: network.Copy(), + } + return peer, emptyMap, nil, 0, nil + } + + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, nil, nil, 0, err + } + + startPosture := time.Now() + postureChecks, err := c.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, 0, err + } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) + + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, nil, nil, 0, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + return peer, networkMap, postureChecks, dnsFwdPort, nil +} + +func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + c.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (c *Controller) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := c.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := c.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + c.updateAccountInHolder(account) +} + +func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if c.experimentalNetworkMap(accountId) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + c.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (c *Controller) experimentalNetworkMap(accountId string) bool { + _, ok := c.expNewNetworkMapAIDs[accountId] + return c.expNewNetworkMap || ok +} + +func (c *Controller) enrichAccountFromHolder(account *types.Account) { + a := c.holder.GetAccount(account.Id) + if a == nil { + c.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + c.holder.AddAccount(account) +} + +func (c *Controller) getAccountFromHolder(accountID string) *types.Account { + return c.holder.GetAccount(accountID) +} + +func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account { + a := c.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (c *Controller) updateAccountInHolder(account *types.Account) { + c.holder.AddAccount(account) +} + +// GetDNSDomain returns the configured dnsDomain +func (c *Controller) GetDNSDomain(settings *types.Settings) string { + if settings == nil { + return c.dnsDomain + } + if settings.DNSDomain == "" { + return c.dnsDomain + } + + return settings.DNSDomain +} + +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) + + if len(account.PostureChecks) == 0 { + return nil, nil + } + + for _, policy := range account.Policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err + } + } + + return maps.Values(peerPostureChecks), nil +} + +func (c *Controller) StartWarmup(ctx context.Context) { + var initialInterval int64 + intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") + interval, err := strconv.Atoi(intervalStr) + if err != nil { + initialInterval = 1 + log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) + } else { + initialInterval = int64(interval) * 10 + go func() { + startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") + startupPeriod, err := strconv.Atoi(startupPeriodStr) + if err != nil { + startupPeriod = 1 + log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) + } + time.Sleep(time.Duration(startupPeriod) * time.Second) + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) + }() + } + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) + +} + +// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. +// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. +func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { + if len(peers) == 0 { + return int64(network_map.OldForwarderPort) + } + + reqVer := semver.Canonical(requiredVersion) + + // Check if all peers have the required version or newer + for _, peer := range peers { + + // Development version is always supported + if peer.Meta.WtVersion == "development" { + continue + } + peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) + if peerVersion == "" { + // If any peer doesn't have version info, return 0 + return int64(network_map.OldForwarderPort) + } + + // Compare versions + if semver.Compare(peerVersion, reqVer) < 0 { + return int64(network_map.OldForwarderPort) + } + } + + // All peers have the required version or newer + return int64(network_map.DnsForwarderPort) +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) + if err != nil { + return err + } + + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck := account.GetPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil +} + +// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, sourceGroup := range rule.Sources { + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") + } + + if slices.Contains(group.Peers, peerID) { + return true, nil + } + } + } + + return false, nil +} + +func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) { + c.UpdatePeerInNetworkMapCache(accountId, peer) + _ = c.bufferSendUpdateAccountPeers(context.Background(), accountId) +} + +func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + err = c.onPeerAddedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } + } + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } + } + + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) +func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + account, err := c.repo.GetAccountByPeerID(ctx, peerID) + if err != nil { + return nil, err + } + + peer := account.GetPeer(peerID) + if peer == nil { + return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) + } + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, err + } + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(peer.AccountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + return networkMap, nil +} + +func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) { + c.peersUpdateManager.CloseChannels(ctx, peerIDs) +} + +func (c *Controller) IsConnected(peerID string) bool { + return c.peersUpdateManager.HasChannel(peerID) +} diff --git a/management/internals/controllers/network_map/controller/controller_test.go b/management/internals/controllers/network_map/controller/controller_test.go new file mode 100644 index 000000000..baaffe677 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller_test.go @@ -0,0 +1,244 @@ +package controller + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/server/mock_server" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestComputeForwarderPort(t *testing.T) { + // Test with empty peers list + peers := []*nbpeer.Peer{} + result := computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have old versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.26.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have new versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.DnsForwarderPort) { + t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have mixed versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have empty version + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result) + } + + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result == int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have unknown version string + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "unknown", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result) + } +} + +func TestBufferUpdateAccountPeers(t *testing.T) { + const ( + peersCount = 1000 + updateAccountInterval = 50 * time.Millisecond + ) + + var ( + deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 + uapLastRun, dpLastRun atomic.Int64 + + totalNewRuns, totalOldRuns int + ) + + uap := func(ctx context.Context, accountID string) { + updatePeersDeleted.Store(deletedPeers.Load()) + updatePeersRuns.Add(1) + uapLastRun.Store(time.Now().UnixMilli()) + time.Sleep(100 * time.Millisecond) + } + + t.Run("new approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) + b := mu.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + uap(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + b.next = time.AfterFunc(updateAccountInterval, func() { + uap(ctx, accountID) + }) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalNewRuns = int(updatePeersRuns.Load()) + }) + + t.Run("old approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) + b := mu.(*sync.Mutex) + + if !b.TryLock() { + return + } + + go func() { + time.Sleep(updateAccountInterval) + b.Unlock() + uap(ctx, accountID) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalOldRuns = int(updatePeersRuns.Load()) + }) + assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) + t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) +} diff --git a/management/internals/controllers/network_map/controller/metrics.go b/management/internals/controllers/network_map/controller/metrics.go new file mode 100644 index 000000000..5832d2130 --- /dev/null +++ b/management/internals/controllers/network_map/controller/metrics.go @@ -0,0 +1,15 @@ +package controller + +import ( + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type metrics struct { + *telemetry.UpdateChannelMetrics +} + +func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) { + return &metrics{ + updateChannelMetrics, + }, nil +} diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go new file mode 100644 index 000000000..44144263b --- /dev/null +++ b/management/internals/controllers/network_map/controller/repository.go @@ -0,0 +1,39 @@ +package controller + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Repository interface { + GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) + GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) +} + +type repository struct { + store store.Store +} + +var _ Repository = (*repository)(nil) + +func newRepository(s store.Store) Repository { + return &repository{ + store: s, + } +} + +func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) { + return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) +} + +func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) { + return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { + return r.store.GetAccountByPeerID(ctx, peerID) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go new file mode 100644 index 000000000..6f893ce79 --- /dev/null +++ b/management/internals/controllers/network_map/interface.go @@ -0,0 +1,39 @@ +package network_map + +//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" + + DnsForwarderPort = nbdns.ForwarderServerPort + OldForwarderPort = nbdns.ForwarderClientPort + DnsForwarderPortMinVersion = "v0.59.0" +) + +type Controller interface { + UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error + BufferUpdateAccountPeers(ctx context.Context, accountID string) error + GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + GetDNSDomain(settings *types.Settings) string + StartWarmup(context.Context) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + + DeletePeer(ctx context.Context, accountId string, peerId string) error + + OnPeerUpdated(accountId string, peer *nbpeer.Peer) + OnPeerAdded(ctx context.Context, accountID string, peerID string) error + OnPeerDeleted(ctx context.Context, accountID string, peerID string) error + DisconnectPeers(ctx context.Context, peerIDs []string) + IsConnected(peerID string) bool +} diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go new file mode 100644 index 000000000..aaa093e47 --- /dev/null +++ b/management/internals/controllers/network_map/interface_mock.go @@ -0,0 +1,225 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod +// + +// Package network_map is a generated GoMock package. +package network_map + +import ( + context "context" + reflect "reflect" + + peer "github.com/netbirdio/netbird/management/server/peer" + posture "github.com/netbirdio/netbird/management/server/posture" + types "github.com/netbirdio/netbird/management/server/types" + gomock "go.uber.org/mock/gomock" +) + +// MockController is a mock of Controller interface. +type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder + isgomock struct{} +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// BufferUpdateAccountPeers mocks base method. +func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. +func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) +} + +// DeletePeer mocks base method. +func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeer indicates an expected call of DeletePeer. +func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId) +} + +// DisconnectPeers mocks base method. +func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs) +} + +// DisconnectPeers indicates an expected call of DisconnectPeers. +func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs) +} + +// GetDNSDomain mocks base method. +func (m *MockController) GetDNSDomain(settings *types.Settings) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDNSDomain", settings) + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDNSDomain indicates an expected call of GetDNSDomain. +func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings) +} + +// GetNetworkMap mocks base method. +func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID) + ret0, _ := ret[0].(*types.NetworkMap) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNetworkMap indicates an expected call of GetNetworkMap. +func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID) +} + +// GetValidatedPeerWithMap mocks base method. +func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(int64) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 +} + +// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap. +func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) +} + +// IsConnected mocks base method. +func (m *MockController) IsConnected(peerID string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsConnected", peerID) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsConnected indicates an expected call of IsConnected. +func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID) +} + +// OnPeerAdded mocks base method. +func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeerAdded indicates an expected call of OnPeerAdded. +func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID) +} + +// OnPeerDeleted mocks base method. +func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeerDeleted indicates an expected call of OnPeerDeleted. +func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID) +} + +// OnPeerUpdated mocks base method. +func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPeerUpdated", accountId, peer) +} + +// OnPeerUpdated indicates an expected call of OnPeerUpdated. +func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer) +} + +// StartWarmup mocks base method. +func (m *MockController) StartWarmup(arg0 context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartWarmup", arg0) +} + +// StartWarmup indicates an expected call of StartWarmup. +func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0) +} + +// UpdateAccountPeer mocks base method. +func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeer indicates an expected call of UpdateAccountPeer. +func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId) +} + +// UpdateAccountPeers mocks base method. +func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeers indicates an expected call of UpdateAccountPeers. +func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) +} diff --git a/management/internals/controllers/network_map/network_map.go b/management/internals/controllers/network_map/network_map.go new file mode 100644 index 000000000..e915c2193 --- /dev/null +++ b/management/internals/controllers/network_map/network_map.go @@ -0,0 +1 @@ +package network_map diff --git a/management/internals/controllers/network_map/update_channel.go b/management/internals/controllers/network_map/update_channel.go new file mode 100644 index 000000000..0b085b85f --- /dev/null +++ b/management/internals/controllers/network_map/update_channel.go @@ -0,0 +1,13 @@ +package network_map + +import "context" + +type PeersUpdateManager interface { + SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) + CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage + CloseChannel(ctx context.Context, peerID string) + CountStreams() int + HasChannel(peerID string) bool + CloseChannels(ctx context.Context, peerIDs []string) + GetAllConnectedPeers() map[string]struct{} +} diff --git a/management/server/updatechannel.go b/management/internals/controllers/network_map/update_channel/updatechannel.go similarity index 87% rename from management/server/updatechannel.go rename to management/internals/controllers/network_map/update_channel/updatechannel.go index adf64592a..5f7db5300 100644 --- a/management/server/updatechannel.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel.go @@ -1,4 +1,4 @@ -package server +package update_channel import ( "context" @@ -7,36 +7,34 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 -type UpdateMessage struct { - Update *proto.SyncResponse -} - type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID - peerChannels map[string]chan *UpdateMessage + peerChannels map[string]chan *network_map.UpdateMessage // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } +var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil) + // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), + peerChannels: make(map[string]chan *network_map.UpdateMessage), channelsMux: &sync.RWMutex{}, metrics: metrics, } } // SendUpdate sends update message to the peer's channel -func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { +func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) { start := time.Now() var found, dropped bool @@ -64,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { +func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage { start := time.Now() closed := false @@ -83,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c close(channel) } // mbragin: todo shouldn't it be more? or configurable? - channel := make(chan *UpdateMessage, channelBufferSize) + channel := make(chan *network_map.UpdateMessage, channelBufferSize) p.peerChannels[peerID] = channel log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) @@ -174,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +func (p *PeersUpdateManager) CountStreams() int { + p.channelsMux.RLock() + defer p.channelsMux.RUnlock() + return len(p.peerChannels) +} diff --git a/management/server/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go similarity index 89% rename from management/server/updatechannel_test.go rename to management/internals/controllers/network_map/update_channel/updatechannel_test.go index 0dc86563d..afc1e2c32 100644 --- a/management/server/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -1,10 +1,11 @@ -package server +package update_channel import ( "context" "testing" "time" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &UpdateMessage{Update: &proto.SyncResponse{ + update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 0, }, @@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &UpdateMessage{Update: &proto.SyncResponse{ + update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 10, }, diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go new file mode 100644 index 000000000..33643bcbd --- /dev/null +++ b/management/internals/controllers/network_map/update_message.go @@ -0,0 +1,9 @@ +package network_map + +import ( + "github.com/netbirdio/netbird/shared/management/proto" +) + +type UpdateMessage struct { + Update *proto.SyncResponse +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 16e93a549..eadd16c2d 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -22,7 +22,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" @@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator()) + srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index ddd81daa2..b61e33688 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" @@ -14,9 +18,9 @@ import ( "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) -func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { - return Create(s, func() *server.PeersUpdateManager { - return server.NewPeersUpdateManager(s.Metrics()) +func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { + return Create(s, func() *update_channel.PeersUpdateManager { + return update_channel.NewPeersUpdateManager(s.Metrics()) }) } @@ -40,9 +44,9 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller { }) } -func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager { - return Create(s, func() *server.TimeBasedAuthSecretsManager { - return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) +func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager { + return Create(s, func() *grpc.TimeBasedAuthSecretsManager { + return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) }) } @@ -63,3 +67,15 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager { return manager.NewEphemeralManager(s.Store(), s.AccountManager()) }) } + +func (s *BaseServer) NetworkMapController() network_map.Controller { + return Create(s, func() *nmapcontroller.Controller { + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController()) + }) +} + +func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { + return Create(s, func() *server.AccountRequestBuffer { + return server.NewAccountRequestBuffer(context.Background(), s.Store()) + }) +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 209a20065..409bdaaba 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -66,8 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, - s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go new file mode 100644 index 000000000..9a4681eae --- /dev/null +++ b/management/internals/shared/grpc/conversion.go @@ -0,0 +1,352 @@ +package grpc + +import ( + "context" + "fmt" + + integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { + if config == nil { + return nil + } + + var stuns []*proto.HostConfig + for _, stun := range config.Stuns { + stuns = append(stuns, &proto.HostConfig{ + Uri: stun.URI, + Protocol: ToResponseProto(stun.Proto), + }) + } + + var turns []*proto.ProtectedHostConfig + if config.TURNConfig != nil { + for _, turn := range config.TURNConfig.Turns { + var username string + var password string + if turnCredentials != nil { + username = turnCredentials.Payload + password = turnCredentials.Signature + } else { + username = turn.Username + password = turn.Password + } + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: turn.URI, + Protocol: ToResponseProto(turn.Proto), + }, + User: username, + Password: password, + }) + } + } + + var relayCfg *proto.RelayConfig + if config.Relay != nil && len(config.Relay.Addresses) > 0 { + relayCfg = &proto.RelayConfig{ + Urls: config.Relay.Addresses, + } + + if relayToken != nil { + relayCfg.TokenPayload = relayToken.Payload + relayCfg.TokenSignature = relayToken.Signature + } + } + + var signalCfg *proto.HostConfig + if config.Signal != nil { + signalCfg = &proto.HostConfig{ + Uri: config.Signal.URI, + Protocol: ToResponseProto(config.Signal.Proto), + } + } + + nbConfig := &proto.NetbirdConfig{ + Stuns: stuns, + Turns: turns, + Signal: signalCfg, + Relay: relayCfg, + } + + return nbConfig +} + +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { + netmask, _ := network.Net.Mask.Size() + fqdn := peer.FQDN(dnsName) + return &proto.PeerConfig{ + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, + LazyConnectionEnabled: settings.LazyConnectionEnabled, + } +} + +func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + response := &proto.SyncResponse{ + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), + NetworkMap: &proto.NetworkMap{ + Serial: networkMap.Network.CurrentSerial(), + Routes: toProtocolRoutes(networkMap.Routes), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + }, + Checks: toProtocolChecks(ctx, checks), + } + + nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) + extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) + response.NetbirdConfig = extendedConfig + + response.NetworkMap.PeerConfig = response.PeerConfig + + remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + response.RemotePeers = remotePeers + response.NetworkMap.RemotePeers = remotePeers + response.RemotePeersIsEmpty = len(remotePeers) == 0 + response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty + + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + response.NetworkMap.FirewallRules = firewallRules + response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + + if networkMap.ForwardingRules != nil { + forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) + for _, rule := range networkMap.ForwardingRules { + forwardingRules = append(forwardingRules, rule.ToProto()) + } + response.NetworkMap.ForwardingRules = forwardingRules + } + + return response +} + +func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { + for _, rPeer := range peers { + dst = append(dst, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: []string{rPeer.IP.String() + "/32"}, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: rPeer.FQDN(dnsName), + AgentVersion: rPeer.Meta.WtVersion, + }) + } + return dst +} + +// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache +func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig { + protoUpdate := &proto.DNSConfig{ + ServiceEnable: update.ServiceEnable, + CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), + NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, + } + + for _, zone := range update.CustomZones { + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) + } + + for _, nsGroup := range update.NameServerGroups { + cacheKey := nsGroup.ID + if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) + } else { + protoGroup := convertToProtoNameServerGroup(nsGroup) + cache.SetNameServerGroup(cacheKey, protoGroup) + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) + } + } + + return protoUpdate +} + +func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { + switch configProto { + case nbconfig.UDP: + return proto.HostConfig_UDP + case nbconfig.DTLS: + return proto.HostConfig_DTLS + case nbconfig.HTTP: + return proto.HostConfig_HTTP + case nbconfig.HTTPS: + return proto.HostConfig_HTTPS + case nbconfig.TCP: + return proto.HostConfig_TCP + default: + panic(fmt.Errorf("unexpected config protocol type %v", configProto)) + } +} + +func toProtocolRoutes(routes []*route.Route) []*proto.Route { + protoRoutes := make([]*proto.Route, 0, len(routes)) + for _, r := range routes { + protoRoutes = append(protoRoutes, toProtocolRoute(r)) + } + return protoRoutes +} + +func toProtocolRoute(route *route.Route) *proto.Route { + return &proto.Route{ + ID: string(route.ID), + NetID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), + NetworkType: int64(route.NetworkType), + Peer: route.Peer, + Metric: int64(route.Metric), + Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, + SkipAutoApply: route.SkipAutoApply, + } +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + fwRule := &proto.FirewallRule{ + PolicyID: []byte(rule.PolicyID), + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, + } + + if shouldUsePortRange(fwRule) { + fwRule.PortInfo = rule.PortRange.ToProto() + } + + result[i] = fwRule + } + return result +} + +// getProtoDirection converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == types.FirewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + Domains: rule.Domains.ToPunycodeList(), + PolicyID: []byte(rule.PolicyID), + RouteID: string(rule.RouteID), + } + } + + return result +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(types.PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case types.PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case types.PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case types.PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} + +func shouldUsePortRange(rule *proto.FirewallRule) bool { + return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) +} + +// Helper function to convert nbdns.CustomZone to proto.CustomZone +func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { + protoZone := &proto.CustomZone{ + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + } + for _, record := range zone.Records { + protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ + Name: record.Name, + Type: int64(record.Type), + Class: record.Class, + TTL: int64(record.TTL), + RData: record.RData, + }) + } + return protoZone +} + +// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup +func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { + protoGroup := &proto.NameServerGroup{ + Primary: nsGroup.Primary, + Domains: nsGroup.Domains, + SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, + NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), + } + for _, ns := range nsGroup.NameServers { + protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ + IP: ns.IP.String(), + Port: int64(ns.Port), + NSType: int64(ns.NSType), + }) + } + return protoGroup +} diff --git a/management/internals/shared/grpc/conversion_test.go b/management/internals/shared/grpc/conversion_test.go new file mode 100644 index 000000000..701271345 --- /dev/null +++ b/management/internals/shared/grpc/conversion_test.go @@ -0,0 +1,150 @@ +package grpc + +import ( + "fmt" + "net/netip" + "reflect" + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" +) + +func TestToProtocolDNSConfigWithCache(t *testing.T) { + var cache cache.DNSConfigCache + + // Create two different configs + config1 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.com", + Records: []nbdns.SimpleRecord{ + {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group1", + Name: "Group 1", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, + }, + }, + }, + } + + config2 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.org", + Records: []nbdns.SimpleRecord{ + {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group2", + Name: "Group 2", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, + }, + }, + }, + } + + // First run with config1 + result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Second run with config2 + result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort)) + + // Third run with config1 again + result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Verify that result1 and result3 are identical + if !reflect.DeepEqual(result1, result3) { + t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) + } + + // Verify that result2 is different from result1 and result3 + if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { + t.Errorf("Results should be different for different inputs") + } + + if _, exists := cache.GetNameServerGroup("group1"); !exists { + t.Errorf("Cache should contain name server group 'group1'") + } + + if _, exists := cache.GetNameServerGroup("group2"); !exists { + t.Errorf("Cache should contain name server group 'group2'") + } +} + +func BenchmarkToProtocolDNSConfig(b *testing.B) { + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + testData := generateTestData(size) + + b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { + cache := &cache.DNSConfigCache{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + + b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache := &cache.DNSConfigCache{} + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + } +} + +func generateTestData(size int) nbdns.Config { + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: make([]nbdns.CustomZone, size), + NameServerGroups: make([]*nbdns.NameServerGroup, size), + } + + for i := 0; i < size; i++ { + config.CustomZones[i] = nbdns.CustomZone{ + Domain: fmt.Sprintf("domain%d.com", i), + Records: []nbdns.SimpleRecord{ + { + Name: fmt.Sprintf("record%d", i), + Type: 1, + Class: "IN", + TTL: 3600, + RData: "192.168.1.1", + }, + }, + } + + config.NameServerGroups[i] = &nbdns.NameServerGroup{ + ID: fmt.Sprintf("group%d", i), + Primary: i == 0, + Domains: []string{fmt.Sprintf("domain%d.com", i)}, + SearchDomainsEnabled: true, + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + Port: 53, + NSType: 1, + }, + }, + } + } + + return config +} diff --git a/management/server/loginfilter.go b/management/internals/shared/grpc/loginfilter.go similarity index 99% rename from management/server/loginfilter.go rename to management/internals/shared/grpc/loginfilter.go index 8604af6e2..59f69dd90 100644 --- a/management/server/loginfilter.go +++ b/management/internals/shared/grpc/loginfilter.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/loginfilter_test.go b/management/internals/shared/grpc/loginfilter_test.go similarity index 99% rename from management/server/loginfilter_test.go rename to management/internals/shared/grpc/loginfilter_test.go index 65782dd9d..8b26e14ab 100644 --- a/management/server/loginfilter_test.go +++ b/management/internals/shared/grpc/loginfilter_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/grpcserver.go b/management/internals/shared/grpc/server.go similarity index 76% rename from management/server/grpcserver.go rename to management/internals/shared/grpc/server.go index d3d94443a..08a840316 100644 --- a/management/server/grpcserver.go +++ b/management/internals/shared/grpc/server.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -22,7 +22,7 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -51,13 +51,13 @@ const ( defaultSyncLim = 1000 ) -// GRPCServer an instance of a Management gRPC API server -type GRPCServer struct { +// Server an instance of a Management gRPC API server +type Server struct { accountManager account.Manager settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager + peersUpdateManager network_map.PeersUpdateManager config *nbconfig.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics @@ -69,23 +69,27 @@ type GRPCServer struct { blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + loginFilter *loginFilter + + networkMapController network_map.Controller + syncSem atomic.Int32 syncLim int32 } // NewServer creates a new Management server func NewServer( - ctx context.Context, config *nbconfig.Config, accountManager account.Manager, settingsManager settings.Manager, - peersUpdateManager *PeersUpdateManager, + peersUpdateManager network_map.PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, -) (*GRPCServer, error) { + networkMapController network_map.Controller, +) (*Server, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -94,7 +98,7 @@ func NewServer( if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(len(peersUpdateManager.peerChannels)) + return int64(peersUpdateManager.CountStreams()) }) if err != nil { return nil, err @@ -115,7 +119,7 @@ func NewServer( } } - return &GRPCServer{ + return &Server{ wgKey: key, // peerKey -> event channel peersUpdateManager: peersUpdateManager, @@ -129,12 +133,15 @@ func NewServer( logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + networkMapController: networkMapController, + + loginFilter: newLoginFilter(), syncLim: syncLim, }, nil } -func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { +func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { ip := "" p, ok := peer.FromContext(ctx) if ok { @@ -171,7 +178,7 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) -func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { if s.syncSem.Load() >= s.syncLim { return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") } @@ -191,7 +198,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi sRealIP := realIP.String() peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } @@ -245,35 +252,29 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + metahash := metaHash(peerMeta, realIP.String()) + s.loginFilter.addLogin(peerKey.String(), metahash) + + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return mapError(ctx, err) } - log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) - - err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return err } - log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart)) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart)) - s.ephemeralManager.OnPeerConnected(ctx, peer) - log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart)) - s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) - log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart)) - if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } @@ -281,15 +282,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi unlock() unlock = nil - log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) - s.syncSem.Add(-1) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { @@ -323,7 +322,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { s.cancelPeerRoutines(ctx, accountID, peer) @@ -341,7 +340,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w return nil } -func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() @@ -356,7 +355,7 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) } -func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { +func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { if s.authManager == nil { return "", status.Errorf(codes.Internal, "missing auth manager") } @@ -390,7 +389,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return userAuth.UserId, nil } -func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { +func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) start := time.Now() @@ -498,7 +497,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee } } -func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { +func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) @@ -517,7 +516,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case it is, the login is successful // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case of the successful registration login is also successful -func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() realIP := getRealIP(ctx) sRealIP := realIP.String() @@ -531,7 +530,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.logBlockedPeers { log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) } @@ -628,7 +627,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p }, nil } -func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { +func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { var relayToken *Token var err error if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { @@ -647,7 +646,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings), Checks: toProtocolChecks(ctx, postureChecks), } @@ -659,7 +658,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // // The user ID can be empty if the token is not provided, which is acceptable if the peer is already // registered or if it uses a setup key to register. -func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { +func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { userID := "" if loginReq.GetJwtToken() != "" { var err error @@ -679,166 +678,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR return userID, nil } -func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { - switch configProto { - case nbconfig.UDP: - return proto.HostConfig_UDP - case nbconfig.DTLS: - return proto.HostConfig_DTLS - case nbconfig.HTTP: - return proto.HostConfig_HTTP - case nbconfig.HTTPS: - return proto.HostConfig_HTTPS - case nbconfig.TCP: - return proto.HostConfig_TCP - default: - panic(fmt.Errorf("unexpected config protocol type %v", configProto)) - } -} - -func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { - if config == nil { - return nil - } - - var stuns []*proto.HostConfig - for _, stun := range config.Stuns { - stuns = append(stuns, &proto.HostConfig{ - Uri: stun.URI, - Protocol: ToResponseProto(stun.Proto), - }) - } - - var turns []*proto.ProtectedHostConfig - if config.TURNConfig != nil { - for _, turn := range config.TURNConfig.Turns { - var username string - var password string - if turnCredentials != nil { - username = turnCredentials.Payload - password = turnCredentials.Signature - } else { - username = turn.Username - password = turn.Password - } - turns = append(turns, &proto.ProtectedHostConfig{ - HostConfig: &proto.HostConfig{ - Uri: turn.URI, - Protocol: ToResponseProto(turn.Proto), - }, - User: username, - Password: password, - }) - } - } - - var relayCfg *proto.RelayConfig - if config.Relay != nil && len(config.Relay.Addresses) > 0 { - relayCfg = &proto.RelayConfig{ - Urls: config.Relay.Addresses, - } - - if relayToken != nil { - relayCfg.TokenPayload = relayToken.Payload - relayCfg.TokenSignature = relayToken.Signature - } - } - - var signalCfg *proto.HostConfig - if config.Signal != nil { - signalCfg = &proto.HostConfig{ - Uri: config.Signal.URI, - Protocol: ToResponseProto(config.Signal.Proto), - } - } - - nbConfig := &proto.NetbirdConfig{ - Stuns: stuns, - Turns: turns, - Signal: signalCfg, - Relay: relayCfg, - } - - return nbConfig -} - -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { - netmask, _ := network.Net.Mask.Size() - fqdn := peer.FQDN(dnsName) - return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, - Fqdn: fqdn, - RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, - LazyConnectionEnabled: settings.LazyConnectionEnabled, - } -} - -func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { - response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), - NetworkMap: &proto.NetworkMap{ - Serial: networkMap.Network.CurrentSerial(), - Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), - }, - Checks: toProtocolChecks(ctx, checks), - } - - nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) - extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) - response.NetbirdConfig = extendedConfig - - response.NetworkMap.PeerConfig = response.PeerConfig - - remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) - response.RemotePeers = remotePeers - response.NetworkMap.RemotePeers = remotePeers - response.RemotePeersIsEmpty = len(remotePeers) == 0 - response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty - - response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) - - firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) - response.NetworkMap.FirewallRules = firewallRules - response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 - - routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) - response.NetworkMap.RoutesFirewallRules = routesFirewallRules - response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 - - if networkMap.ForwardingRules != nil { - forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) - for _, rule := range networkMap.ForwardingRules { - forwardingRules = append(forwardingRules, rule.ToProto()) - } - response.NetworkMap.ForwardingRules = forwardingRules - } - - return response -} - -func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { - for _, rPeer := range peers { - dst = append(dst, &proto.RemotePeerConfig{ - WgPubKey: rPeer.Key, - AllowedIps: []string{rPeer.IP.String() + "/32"}, - SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, - Fqdn: rPeer.FQDN(dnsName), - AgentVersion: rPeer.Meta.WtVersion, - }) - } - return dst -} - // IsHealthy indicates whether the service is healthy -func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { +func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { return &proto.Empty{}, nil } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error { var err error var turnToken *Token @@ -862,19 +708,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } - peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID) + peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID) if err != nil { return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - // Get all peers in the account for forwarder port computation - allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "") - if err != nil { - return fmt.Errorf("get account peers: %w", err) - } - dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion) - - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { @@ -899,7 +738,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p // GetDeviceAuthorizationFlow returns a device authorization flow information // This is used for initiating an Oauth 2 device authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) start := time.Now() defer func() { @@ -957,7 +796,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. // GetPKCEAuthorizationFlow returns a pkce authorization flow information // This is used for initiating an Oauth 2 pkce authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) start := time.Now() defer func() { @@ -1012,7 +851,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En // SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected, // peer's under the same account of any updates. -func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) @@ -1037,7 +876,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) return &proto.Empty{}, nil } -func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) start := time.Now() diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go new file mode 100644 index 000000000..9867b38e3 --- /dev/null +++ b/management/internals/shared/grpc/server_test.go @@ -0,0 +1,106 @@ +package grpc + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { + testingServerKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testingClientKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testCases := []struct { + name string + inputFlow *config.DeviceAuthorizationFlow + expectedFlow *mgmtProto.DeviceAuthorizationFlow + expectedErrFunc require.ErrorAssertionFunc + expectedErrMSG string + expectedComparisonFunc require.ComparisonAssertionFunc + expectedComparisonMSG string + }{ + { + name: "Testing No Device Flow Config", + inputFlow: nil, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Invalid Device Flow Provider Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "NoNe", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Full Device Flow Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "hosted", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ + Provider: 0, + ProviderConfig: &mgmtProto.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.NoError, + expectedErrMSG: "should not return error", + expectedComparisonFunc: require.Equal, + expectedComparisonMSG: "should match", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mgmtServer := &Server{ + wgKey: testingServerKey, + config: &config.Config{ + DeviceAuthorizationFlow: testCase.inputFlow, + }, + } + + message := &mgmtProto.DeviceAuthorizationFlowRequest{} + + encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) + require.NoError(t, err, "should be able to encrypt message") + + resp, err := mgmtServer.GetDeviceAuthorizationFlow( + context.TODO(), + &mgmtProto.EncryptedMessage{ + WgPubKey: testingClientKey.PublicKey().String(), + Body: encryptedMSG, + }, + ) + testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) + if testCase.expectedComparisonFunc != nil { + flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} + + err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) + require.NoError(t, err, "should be able to decrypt") + + testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) + testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) + } + }) + } +} diff --git a/management/server/token_mgr.go b/management/internals/shared/grpc/token_mgr.go similarity index 93% rename from management/server/token_mgr.go rename to management/internals/shared/grpc/token_mgr.go index f9293e7a8..e9770db41 100644 --- a/management/server/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -37,7 +38,7 @@ type TimeBasedAuthSecretsManager struct { relayCfg *nbconfig.Relay turnHmacToken *auth.TimedHMAC relayHmacToken *authv2.Generator - updateManager *PeersUpdateManager + updateManager network_map.PeersUpdateManager settingsManager settings.Manager groupsManager groups.Manager turnCancelMap map[string]chan struct{} @@ -46,7 +47,7 @@ type TimeBasedAuthSecretsManager struct { type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -227,7 +228,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -251,7 +252,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/server/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go similarity index 94% rename from management/server/token_mgr_test.go rename to management/internals/shared/grpc/token_mgr_test.go index 5c956dc31..06d28d05b 100644 --- a/management/server/token_mgr_test.go +++ b/management/internals/shared/grpc/token_mgr_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -13,6 +13,8 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) rc := &config.Relay{ Addresses: []string{"localhost:0"}, @@ -80,7 +82,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { ttl := util.Duration{Duration: 2 * time.Second} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" updateChannel := peersManager.CreateChannel(context.Background(), peer) @@ -116,7 +118,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { t.Errorf("expecting peer to be present in the relay cancel map, got not present") } - var updates []*UpdateMessage + var updates []*network_map.UpdateMessage loop: for timeout := time.After(5 * time.Second); ; { @@ -185,7 +187,7 @@ loop: func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" rc := &config.Relay{ diff --git a/management/server/account.go b/management/server/account.go index f5a5c7b7a..a4b2a752b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -11,10 +11,8 @@ import ( "reflect" "regexp" "slices" - "strconv" "strings" "sync" - "sync/atomic" "time" cacheStore "github.com/eko/gocache/lib/v4/store" @@ -26,6 +24,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -53,9 +52,6 @@ const ( peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" - - envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" - envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" ) type userLoggedInOnce bool @@ -71,7 +67,7 @@ type DefaultAccountManager struct { cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded cacheLoading map[string]chan struct{} - peersUpdateManager *PeersUpdateManager + networkMapController network_map.Controller idpManager idp.Manager cacheManager *nbcache.AccountUserDataCache externalCacheManager nbcache.UserDataCache @@ -91,8 +87,7 @@ type DefaultAccountManager struct { singleAccountMode bool // singleAccountModeDomain is a domain to use in singleAccountMode setup singleAccountModeDomain string - // dnsDomain is used for peer resolution. This is appended to the peer's name - dnsDomain string + peerLoginExpiry Scheduler peerInactivityExpiry Scheduler @@ -106,19 +101,11 @@ type DefaultAccountManager struct { permissionsManager permissions.Manager - accountUpdateLocks sync.Map - updateAccountPeersBufferInterval atomic.Int64 - - loginFilter *loginFilter - disableDefaultPolicy bool - - holder *types.Holder - - expNewNetworkMap bool - expNewNetworkMapAIDs map[string]struct{} } +var _ account.Manager = (*DefaultAccountManager)(nil) + func isUniqueConstraintError(err error) bool { switch { case strings.Contains(err.Error(), "(SQLSTATE 23505)"), @@ -185,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] func BuildManager( ctx context.Context, store store.Store, - peersUpdateManager *PeersUpdateManager, + networkMapController network_map.Controller, idpManager idp.Manager, singleAccountModeDomain string, - dnsDomain string, eventStore activity.Store, geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, @@ -204,27 +190,14 @@ func BuildManager( log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) }() - newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err) - newNetworkMapBuilder = false - } - - ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",") - expIDs := make(map[string]struct{}, len(ids)) - for _, id := range ids { - expIDs[id] = struct{}{} - } - am := &DefaultAccountManager{ Store: store, geo: geo, - peersUpdateManager: peersUpdateManager, + networkMapController: networkMapController, idpManager: idpManager, ctx: context.Background(), cacheMux: sync.Mutex{}, cacheLoading: map[string]chan struct{}{}, - dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), peerInactivityExpiry: NewDefaultScheduler(), @@ -235,15 +208,10 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, - loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, - holder: types.NewHolder(), - - expNewNetworkMap: newNetworkMapBuilder, - expNewNetworkMapAIDs: expIDs, } - am.startWarmup(ctx) + am.networkMapController.StartWarmup(ctx) accountsCounter, err := store.GetAccountsCounter(ctx) if err != nil { @@ -291,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { am.ephemeralManager = em } -func (am *DefaultAccountManager) startWarmup(ctx context.Context) { - var initialInterval int64 - intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") - interval, err := strconv.Atoi(intervalStr) - if err != nil { - initialInterval = 1 - log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) - } else { - initialInterval = int64(interval) * 10 - go func() { - startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") - startupPeriod, err := strconv.Atoi(startupPeriodStr) - if err != nil { - startupPeriod = 1 - log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) - } - time.Sleep(time.Duration(startupPeriod) * time.Second) - am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) - }() - } - am.updateAccountPeersBufferInterval.Store(initialInterval) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) - -} - func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { return am.externalCacheManager } @@ -419,9 +361,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } go am.UpdateAccountPeers(ctx, accountID) } @@ -1504,10 +1443,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } if removedGroupAffectsPeers || newGroupsAffectsPeers { - if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil { - return err - } - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } @@ -1667,14 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { - return am.loginFilter.allowLogin(wgPubKey, metahash) -} - -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { - return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) + return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) @@ -1682,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - metahash := metaHash(meta, realIP.String()) - am.loginFilter.addLogin(peerPubKey, metahash) - - return peer, netMap, postureChecks, nil + return peer, netMap, postureChecks, dnsfwdPort, nil } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { @@ -1702,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } - _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) + _, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { - return mapError(ctx, err) + return err } return nil } -// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() -func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { - return am.peersUpdateManager.GetAllConnectedPeers(), nil -} - -// HasConnectedChannel returns true if peers has channel in update manager, otherwise false -func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool { - return am.peersUpdateManager.HasChannel(peerID) -} - var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) func isDomainValid(domain string) bool { return invalidDomainRegexp.MatchString(domain) } -// GetDNSDomain returns the configured dnsDomain -func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { - if settings == nil { - return am.dnsDomain - } - if settings.DNSDomain == "" { - return am.dnsDomain - } - - return settings.DNSDomain -} - func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { peers := []*nbpeer.Peer{} log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID) @@ -2159,8 +2065,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us if err != nil { return err } - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(peer.AccountID, peer) } return nil } @@ -2208,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti if err != nil { return fmt.Errorf("get account settings: %w", err) } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) eventMeta := peer.EventMeta(dnsDomain) oldIP := peer.IP.String() diff --git a/management/server/account/manager.go b/management/server/account/manager.go index db377865a..7c174a481 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -89,7 +89,6 @@ type Manager interface { SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - GetDNSDomain(settings *types.Settings) string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) @@ -97,10 +96,8 @@ type Manager interface { GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) - LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - GetAllConnectedPeers() (map[string]struct{}, error) - HasConnectedChannel(peerID string) bool + LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) @@ -110,7 +107,7 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -127,6 +124,4 @@ type Manager interface { GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) - AllowSync(string, uint64) bool - RecalculateNetworkMapCache(ctx context.Context, accountId string) error } diff --git a/management/server/account/request_buffer.go b/management/server/account/request_buffer.go new file mode 100644 index 000000000..eced1929f --- /dev/null +++ b/management/server/account/request_buffer.go @@ -0,0 +1,11 @@ +package account + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/types" +) + +type RequestBuffer interface { + GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 200ba6b98..ee9950796 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -22,6 +22,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) { } func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) @@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { userId := "user-id" domain := "test.domain" _ = newAccountWithId(context.Background(), "", userId, domain, false) - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") @@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { } func TestAccountManager_PrivateAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } func TestAccountManager_SetOrUpdateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } func TestAccountManager_GetAccountByUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string) } func TestAccountManager_GetAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } func TestAccountManager_DeleteAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { DomainCategory: types.PublicCategory, } - am, err := createManager(b) + am, _, err := createManager(b) if err != nil { b.Fatal(err) return @@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User { } func TestAccountManager_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) { } func TestAccountManager_AddPeerWithUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1155,7 +1158,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { } func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_SaveGroup(t) } @@ -1164,7 +1167,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { } func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1190,8 +1193,8 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { }, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1215,7 +1218,7 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeletePolicy(t) } @@ -1224,10 +1227,10 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { } func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { - manager, account, peer1, _, _ := setupNetworkMapTest(t) + manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) // Ensure that we do not receive an update message before the policy is deleted time.Sleep(time.Second) @@ -1258,7 +1261,7 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { } func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_SavePolicy(t) } @@ -1267,7 +1270,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { - manager, account, peer1, peer2, _ := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ AccountID: account.Id, @@ -1280,8 +1283,8 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1316,7 +1319,7 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeletePeer(t) } @@ -1325,7 +1328,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { } func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { - manager, account, peer1, _, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1354,8 +1357,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + // We need to sleep to wait for the buffer peer update + time.Sleep(300 * time.Millisecond) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1378,7 +1384,7 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeleteGroup(t) } @@ -1387,10 +1393,10 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { } func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1457,7 +1463,7 @@ func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { } func TestAccountManager_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1538,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy } func TestGetUsersFromAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1837,7 +1843,7 @@ func hasNilField(x interface{}) error { } func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1852,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1908,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { } func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1951,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2013,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test } func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2677,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", "postgres") - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") // create a new account @@ -2919,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { // Fatalf(format string, args ...interface{}) // } -func createManager(t testing.TB) (*DefaultAccountManager, error) { +func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) if err != nil { - return nil, err + return nil, nil, err } ctrl := gomock.NewController(t) @@ -2948,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { - return nil, err + return nil, nil, err } - return manager, nil + return manager, updateManager, nil } func createStore(t testing.TB) (store.Store, error) { @@ -2982,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() - manager, err := createManager(t) + manager, updateManager, err := createManager(t) if err != nil { t.Fatal(err) } @@ -3026,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, peer2 := getPeer(manager, setupKey) peer3 := getPeer(manager, setupKey) - return manager, account, peer1, peer2, peer3 + return manager, updateManager, account, peer1, peer2, peer3 } -func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -3039,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag } } -func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { @@ -3077,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3086,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) assert.NoError(b, err) } @@ -3140,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3149,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3210,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3219,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3282,7 +3289,7 @@ func TestMain(m *testing.M) { } func Test_GetCreateAccountByPrivateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3328,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) { } func Test_UpdateToPrimaryAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3358,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { } func TestDefaultAccountManager_IsCacheCold(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) t.Run("memory cache", func(t *testing.T) { @@ -3408,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { } func TestPropagateUserGroupMemberships(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) ctx := context.Background() @@ -3525,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { } func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3557,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3596,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -3663,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { } func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -3709,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { } func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/dns.go b/management/server/dns.go index decc5175d..baf6debc3 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,54 +3,23 @@ package server import ( "context" "slices" - "sync" log "github.com/sirupsen/logrus" - "golang.org/x/mod/semver" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) const ( dnsForwarderPort = nbdns.ForwarderServerPort - oldForwarderPort = nbdns.ForwarderClientPort ) -const dnsForwarderPortMinVersion = "v0.59.0" - -// DNSConfigCache is a thread-safe cache for DNS configuration components -type DNSConfigCache struct { - NameServerGroups sync.Map -} - -// GetNameServerGroup retrieves a cached name server group -func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { - if c == nil { - return nil, false - } - if value, ok := c.NameServerGroups.Load(key); ok { - return value.(*proto.NameServerGroup), true - } - return nil, false -} - -// SetNameServerGroup stores a name server group in the cache -func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { - if c == nil { - return - } - c.NameServerGroups.Store(key, value) -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) @@ -117,9 +86,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -194,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return validateGroups(settings.DisabledManagementGroups, groups) } - -// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. -// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. -func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { - if len(peers) == 0 { - return int64(oldForwarderPort) - } - - reqVer := semver.Canonical(requiredVersion) - - // Check if all peers have the required version or newer - for _, peer := range peers { - - // Development version is always supported - if peer.Meta.WtVersion == "development" { - continue - } - peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) - if peerVersion == "" { - // If any peer doesn't have version info, return 0 - return int64(oldForwarderPort) - } - - // Compare versions - if semver.Compare(peerVersion, reqVer) < 0 { - return int64(oldForwarderPort) - } - } - - // All peers have the required version or newer - return int64(dnsForwarderPort) -} - -// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache -func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig { - protoUpdate := &proto.DNSConfig{ - ServiceEnable: update.ServiceEnable, - CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), - NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), - ForwarderPort: forwardPort, - } - - for _, zone := range update.CustomZones { - protoZone := convertToProtoCustomZone(zone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } - - for _, nsGroup := range update.NameServerGroups { - cacheKey := nsGroup.ID - if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) - } else { - protoGroup := convertToProtoNameServerGroup(nsGroup) - cache.SetNameServerGroup(cacheKey, protoGroup) - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) - } - } - - return protoUpdate -} - -// Helper function to convert nbdns.CustomZone to proto.CustomZone -func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { - protoZone := &proto.CustomZone{ - Domain: zone.Domain, - Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), - } - for _, record := range zone.Records { - protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ - Name: record.Name, - Type: int64(record.Type), - Class: record.Class, - TTL: int64(record.TTL), - RData: record.RData, - }) - } - return protoZone -} - -// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup -func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { - protoGroup := &proto.NameServerGroup{ - Primary: nsGroup.Primary, - Domains: nsGroup.Domains, - SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, - NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), - } - for _, ns := range nsGroup.NameServers { - protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ - IP: ns.IP.String(), - Port: int64(ns.Port), - NSType: int64(ns.NSType), - }) - } - return protoGroup -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 96f73a390..356a2f640 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "net/netip" - "reflect" "testing" "time" @@ -12,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { // return empty extra settings for expected calls to UpdateAccountPeers settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock()) + + return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createDNSStore(t *testing.T) (store.Store, error) { @@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return am.Store.GetAccount(context.Background(), account.Id) } -func generateTestData(size int) nbdns.Config { - config := nbdns.Config{ - ServiceEnable: true, - CustomZones: make([]nbdns.CustomZone, size), - NameServerGroups: make([]*nbdns.NameServerGroup, size), - } - - for i := 0; i < size; i++ { - config.CustomZones[i] = nbdns.CustomZone{ - Domain: fmt.Sprintf("domain%d.com", i), - Records: []nbdns.SimpleRecord{ - { - Name: fmt.Sprintf("record%d", i), - Type: 1, - Class: "IN", - TTL: 3600, - RData: "192.168.1.1", - }, - }, - } - - config.NameServerGroups[i] = &nbdns.NameServerGroup{ - ID: fmt.Sprintf("group%d", i), - Primary: i == 0, - Domains: []string{fmt.Sprintf("domain%d.com", i)}, - SearchDomainsEnabled: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("8.8.8.8"), - Port: 53, - NSType: 1, - }, - }, - } - } - - return config -} - -func BenchmarkToProtocolDNSConfig(b *testing.B) { - sizes := []int{10, 100, 1000} - - for _, size := range sizes { - testData := generateTestData(size) - - b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { - cache := &DNSConfigCache{} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) - } - }) - - b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) - } - }) - } -} - -func TestToProtocolDNSConfigWithCache(t *testing.T) { - var cache DNSConfigCache - - // Create two different configs - config1 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.com", - Records: []nbdns.SimpleRecord{ - {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group1", - Name: "Group 1", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, - }, - }, - }, - } - - config2 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.org", - Records: []nbdns.SimpleRecord{ - {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group2", - Name: "Group 2", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, - }, - }, - }, - } - - // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) - - // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) - - // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) - - // Verify that result1 and result3 are identical - if !reflect.DeepEqual(result1, result3) { - t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) - } - - // Verify that result2 is different from result1 and result3 - if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { - t.Errorf("Results should be different for different inputs") - } - - if _, exists := cache.GetNameServerGroup("group1"); !exists { - t.Errorf("Cache should contain name server group 'group1'") - } - - if _, exists := cache.GetNameServerGroup("group2"); !exists { - t.Errorf("Cache should contain name server group 'group2'") - } -} - -func TestComputeForwarderPort(t *testing.T) { - // Test with empty peers list - peers := []*nbpeer.Peer{} - result := computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) - } - - // Test with peers that have old versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.26.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have new versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(dnsForwarderPort) { - t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) - } - - // Test with peers that have mixed versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have empty version - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) - } - - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "development", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result == int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) - } - - // Test with peers that have unknown version string - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "unknown", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) - } -} - func TestDNSAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { @@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates diff --git a/management/server/event_test.go b/management/server/event_test.go index 8c56fd3f6..420e69866 100644 --- a/management/server/event_test.go +++ b/management/server/event_test.go @@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac } func TestDefaultAccountManager_GetEvents(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { return } diff --git a/management/server/group.go b/management/server/group.go index 3cf9290a2..84e641f26 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -114,9 +114,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -185,9 +182,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -256,9 +250,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -327,9 +318,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -376,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) return nil } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peerID := range addedPeers { peer, ok := peers[peerID] @@ -493,9 +481,6 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -534,9 +519,6 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -565,9 +547,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -606,9 +585,6 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/group_test.go b/management/server/group_test.go index 31ff29cbc..4935dac5d 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -37,7 +37,7 @@ const ( ) func TestDefaultAccountManager_CreateGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Fatalf("failed to create account manager: %s", err) } @@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroups(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) assert.NoError(t, err, "Failed to create account manager") manager, account, err := initTestGroupAccount(am) @@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t } func TestGroupAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving a group that is not linked to any resource should not update account peers @@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } func Test_AddPeerToGroup(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) { } func Test_AddPeerToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) { } func Test_AddPeerAndAddToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP { } func Test_IncrementNetworkSerial(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return diff --git a/management/server/holder.go b/management/server/holder.go deleted file mode 100644 index e8a26e1d0..000000000 --- a/management/server/holder.go +++ /dev/null @@ -1,39 +0,0 @@ -package server - -import ( - "github.com/netbirdio/netbird/management/server/types" -) - -func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) { - a := am.holder.GetAccount(account.Id) - if a == nil { - am.holder.AddAccount(account) - return - } - account.NetworkMapCache = a.NetworkMapCache - if account.NetworkMapCache == nil { - return - } - account.NetworkMapCache.UpdateAccountPointer(account) - am.holder.AddAccount(account) -} - -func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { - return am.holder.GetAccount(accountID) -} - -func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account { - a := am.holder.GetAccount(accountID) - if a != nil { - return a - } - account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure) - if err != nil { - return nil - } - return account -} - -func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { - am.holder.AddAccount(account) -} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 4d2c224b4..c1a8c5885 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -13,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" @@ -65,6 +66,7 @@ func NewAPIHandler( permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, + networkMapController network_map.Controller, ) (http.Handler, error) { var rateLimitingConfig *middleware.RateLimiterConfig @@ -120,7 +122,7 @@ func NewAPIHandler( } accounts.AddEndpoints(accountManager, settingsManager, router) - peers.AddEndpoints(accountManager, router) + peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router) policies.AddEndpoints(accountManager, LocationManager, router) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index df89c616c..c4c5ae165 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -23,11 +24,12 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { - accountManager account.Manager + accountManager account.Manager + networkMapController network_map.Controller } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { - peersHandler := NewHandler(accountManager) +func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) { + peersHandler := NewHandler(accountManager, networkMapController) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -36,9 +38,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { } // NewHandler creates a new peers Handler -func NewHandler(accountManager account.Manager) *Handler { +func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler { return &Handler{ - accountManager: accountManager, + accountManager: accountManager, + networkMapController: networkMapController, } } @@ -47,7 +50,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected // This may happen after server restart when not all peers are yet connected - if !h.accountManager.HasConnectedChannel(peer.ID) { + if !h.networkMapController.IsConnected(peer.ID) { peerToReturn.Status.Connected = false } } @@ -73,7 +76,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) @@ -139,7 +142,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { @@ -227,7 +230,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) @@ -317,7 +320,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain(account.Settings) + dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 94564113f..7a5a6d911 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -14,12 +14,14 @@ import ( "time" "github.com/gorilla/mux" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,7 +38,7 @@ const ( serviceUser = "service_user" ) -func initTestMetaData(peers ...*nbpeer.Peer) *Handler { +func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { @@ -99,6 +101,22 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, } + ctrl := gomock.NewController(t) + + networkMapController := network_map.NewMockController(ctrl) + networkMapController.EXPECT(). + GetDNSDomain(gomock.Any()). + Return("domain"). + AnyTimes() + networkMapController.EXPECT(). + IsConnected(noUpdateChannelTestPeerID). + Return(false). + AnyTimes() + networkMapController.EXPECT(). + IsConnected(gomock.Any()). + Return(true). + AnyTimes() + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -187,6 +205,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { return account.Settings, nil }, }, + networkMapController: networkMapController, } } @@ -270,7 +289,7 @@ func TestGetPeers(t *testing.T) { rr := httptest.NewRecorder() - p := initTestMetaData(peer, peer1) + p := initTestMetaData(t, peer, peer1) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -374,7 +393,7 @@ func TestGetAccessiblePeers(t *testing.T) { UserID: regularUser, } - p := initTestMetaData(peer1, peer2, peer3) + p := initTestMetaData(t, peer1, peer2, peer3) tt := []struct { name string @@ -477,7 +496,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { }, } - p := initTestMetaData(testPeer) + p := initTestMetaData(t, testPeer) tt := []struct { name string diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index bdf56db6e..ab3f5437a 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" @@ -31,7 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/users" ) -func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { +func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) if err != nil { t.Fatalf("Failed to create test store: %v", err) @@ -43,7 +47,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee t.Fatalf("Failed to create metrics: %v", err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) + peersUpdateManager := update_channel.NewPeersUpdateManager(nil) updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) done := make(chan struct{}) if validateUpdate { @@ -63,7 +67,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + + ctx := context.Background() + requestBuffer := server.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock()) + am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) } @@ -83,7 +91,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee groupsManagerMock := groups.NewManagerMock() peersManager := peers.NewManager(store, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -91,7 +99,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee return apiHandler, am, done } -func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { +func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -101,7 +109,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server } } -func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { +func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) { t.Helper() select { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index a34d2086b..fc67e01af 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -22,10 +22,14 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -321,99 +325,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ return loginResp, nil } -func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { - testingServerKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testingClientKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testCases := []struct { - name string - inputFlow *config.DeviceAuthorizationFlow - expectedFlow *mgmtProto.DeviceAuthorizationFlow - expectedErrFunc require.ErrorAssertionFunc - expectedErrMSG string - expectedComparisonFunc require.ComparisonAssertionFunc - expectedComparisonMSG string - }{ - { - name: "Testing No Device Flow Config", - inputFlow: nil, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Invalid Device Flow Provider Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "NoNe", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Full Device Flow Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "hosted", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ - Provider: 0, - ProviderConfig: &mgmtProto.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.NoError, - expectedErrMSG: "should not return error", - expectedComparisonFunc: require.Equal, - expectedComparisonMSG: "should match", - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - mgmtServer := &GRPCServer{ - wgKey: testingServerKey, - config: &config.Config{ - DeviceAuthorizationFlow: testCase.inputFlow, - }, - } - - message := &mgmtProto.DeviceAuthorizationFlowRequest{} - - encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) - require.NoError(t, err, "should be able to encrypt message") - - resp, err := mgmtServer.GetDeviceAuthorizationFlow( - context.TODO(), - &mgmtProto.EncryptedMessage{ - WgPubKey: testingClientKey.PublicKey().String(), - Body: encryptedMSG, - }, - ) - testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) - if testCase.expectedComparisonFunc != nil { - flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} - - err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) - require.NoError(t, err, "should be able to decrypt") - - testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) - testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) - } - }) - } -} - func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") @@ -427,7 +338,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config t.Fatal(err) } - peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck @@ -451,7 +361,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() - accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { @@ -459,10 +372,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) ephemeralMgr := manager.NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController) if err != nil { return nil, nil, "", cleanup, err } @@ -764,9 +677,38 @@ func Test_LoginPerformance(t *testing.T) { peerLogin := types.PeerLogin{ WireGuardPubKey: key.String(), SSHKey: "random", - Meta: extractPeerMeta(context.Background(), meta), - SetupKey: setupKey.Key, - ConnectionIP: net.IP{1, 1, 1, 1}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: meta.GetHostname(), + GoOS: meta.GetGoOS(), + Kernel: meta.GetKernel(), + Platform: meta.GetPlatform(), + OS: meta.GetOS(), + OSVersion: meta.GetOSVersion(), + WtVersion: meta.GetNetbirdVersion(), + UIVersion: meta.GetUiVersion(), + KernelVersion: meta.GetKernelVersion(), + SystemSerialNumber: meta.GetSysSerialNumber(), + SystemProductName: meta.GetSysProductName(), + SystemManufacturer: meta.GetSysManufacturer(), + Environment: nbpeer.Environment{ + Cloud: meta.GetEnvironment().GetCloud(), + Platform: meta.GetEnvironment().GetPlatform(), + }, + Flags: nbpeer.Flags{ + RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(), + RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(), + ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(), + DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(), + DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(), + DisableDNS: meta.GetFlags().GetDisableDNS(), + DisableFirewall: meta.GetFlags().GetDisableFirewall(), + BlockLANAccess: meta.GetFlags().GetBlockLANAccess(), + BlockInbound: meta.GetFlags().GetBlockInbound(), + LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(), + }, + }, + SetupKey: setupKey.Key, + ConnectionIP: net.IP{1, 1, 1, 1}, } login := func() error { diff --git a/management/server/management_test.go b/management/server/management_test.go index 1a5e47354..930ecfb5a 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -20,7 +20,10 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" @@ -176,7 +179,6 @@ func startServer( log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) @@ -199,13 +201,18 @@ func startServer( AnyTimes() permissionsManager := permissions.NewManager(str) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(ctx, str) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager( context.Background(), str, - peersUpdateManager, + networkMapController, nil, "", - "netbird.selfhosted", eventStore, nil, false, @@ -220,18 +227,18 @@ func startServer( } groupsManager := groups.NewManager(str, permissionsManager, accountManager) - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer( - context.Background(), + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer( config, accountManager, settingsMockManager, - peersUpdateManager, + updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, + networkMapController, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8baffa58b..781d84f5f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -38,7 +38,7 @@ type MockAccountManager struct { ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -94,7 +94,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error @@ -178,11 +178,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { @@ -747,11 +747,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, accountID) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } // GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index ee77a65bb..f278e1761 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -83,9 +83,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -137,9 +134,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -183,9 +177,6 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 6c985410c..35291b30c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -785,7 +787,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + + return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createNSStore(t *testing.T) (store.Store, error) { @@ -975,7 +983,7 @@ func TestValidateDomain(t *testing.T) { } func TestNameServerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup @@ -994,9 +1002,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Creating a nameserver group with a distribution group no peers should not update account peers diff --git a/management/server/networkmap.go b/management/server/networkmap.go deleted file mode 100644 index 2a0627643..000000000 --- a/management/server/networkmap.go +++ /dev/null @@ -1,80 +0,0 @@ -package server - -import ( - "context" - - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - - nbdns "github.com/netbirdio/netbird/dns" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" -) - -func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { - am.enrichAccountFromHolder(account) - account.InitNetworkMapBuilderIfNeeded(validatedPeers) -} - -func (am *DefaultAccountManager) getPeerNetworkMapExp( - ctx context.Context, - accountId string, - peerId string, - validatedPeers map[string]struct{}, - customZone nbdns.CustomZone, - metrics *telemetry.AccountManagerMetrics, -) *types.NetworkMap { - account := am.getAccountFromHolderOrInit(accountId) - if account == nil { - log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) - return &types.NetworkMap{ - Network: &types.Network{}, - } - } - return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) -} - -func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { - am.enrichAccountFromHolder(account) - return account.OnPeerAddedUpdNetworkMapCache(peerId) -} - -func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { - am.enrichAccountFromHolder(account) - return account.OnPeerDeletedUpdNetworkMapCache(peerId) -} - -func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { - account := am.getAccountFromHolder(accountId) - if account == nil { - return - } - account.UpdatePeerInNetworkMapCache(peer) -} - -func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { - account.RecalculateNetworkMapCache(validatedPeers) - am.updateAccountInHolder(account) -} - -func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { - if am.experimentalNetworkMap(accountId) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - return err - } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) - return err - } - am.recalculateNetworkMapCache(account, validatedPeers) - } - return nil -} - -func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool { - _, ok := am.expNewNetworkMapAIDs[accountId] - return am.expNewNetworkMap || ok -} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 0e6d1631b..b6706ca45 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -177,9 +177,6 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index b740610c2..66484d120 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -157,9 +157,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -260,9 +257,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -337,9 +331,6 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 89ac419fd..82cac424a 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -119,9 +119,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) - if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -186,9 +183,6 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) - if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -223,9 +217,6 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 4c605b5eb..cd9fbe4c8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -8,8 +8,6 @@ import ( "net" "slices" "strings" - "sync" - "sync/atomic" "time" "github.com/rs/xid" @@ -23,7 +21,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -31,7 +28,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -140,12 +136,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(accountID, peer) } return nil @@ -201,7 +192,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var peer *nbpeer.Peer var settings *types.Settings var peerGroupList []string - var requiresPeerUpdates bool var peerLabelChanged bool var sshChanged bool var loginExpirationChanged bool @@ -224,9 +214,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - dnsDomain = am.GetDNSDomain(settings) + dnsDomain = am.networkMapController.GetDNSDomain(settings) - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) + update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) if err != nil { return err } @@ -319,15 +309,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - - if peerLabelChanged || requiresPeerUpdates { - am.UpdateAccountPeers(ctx, accountID) - } else if sshChanged { - am.UpdateAccountPeer(ctx, accountID, peer.ID) - } + am.networkMapController.OnPeerUpdated(accountID, peer) return peer, nil } @@ -383,20 +365,13 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if am.experimentalNetworkMap(accountID) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - - if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) - } - + err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) } - if userID != activity.SystemInitiator { - am.BufferUpdateAccountPeers(ctx, accountID) + if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) } return nil @@ -404,47 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return nil, err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - groups := make(map[string][]string) - for groupID, group := range account.Groups { - groups[groupID] = group.Peers - } - - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, err - } - - var networkMap *types.NetworkMap - - if am.experimentalNetworkMap(peer.AccountID) { - networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return networkMap, nil + return am.networkMapController.GetNetworkMap(ctx, peerID) } // GetPeerNetwork returns the Network for a given peer @@ -703,27 +638,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) + opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings)) if !addedByUser { opEvent.Meta["setup_key_name"] = setupKeyName } am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if am.experimentalNetworkMap(accountID) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) - } + if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } - am.BufferUpdateAccountPeers(ctx, accountID) - - return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer) + return p, nmap, pc, err } func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { @@ -738,7 +665,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool @@ -748,7 +675,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -798,17 +725,14 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil }) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(accountID, peer) } - return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) + return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -933,15 +857,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - startBuffer := time.Now() - am.BufferUpdateAccountPeers(ctx, accountID) - log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer)) + am.networkMapController.OnPeerUpdated(accountID, peer) } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + return p, nmap, pc, err } // getPeerPostureChecks returns the posture checks for the peer. @@ -1033,68 +953,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - if isRequiresApproval { - network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, nil, nil, err - } - - emptyMap := &types.NetworkMap{ - Network: network.Copy(), - } - return peer, emptyMap, nil, nil - } - - var ( - account *types.Account - err error - ) - if am.experimentalNetworkMap(accountID) { - account = am.getAccountFromHolderOrInit(accountID) - } else { - account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, nil, nil, err - } - - startPosture := time.Now() - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) - - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, nil, nil, err - } - - var networkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountID) { - networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return peer, networkMap, postureChecks, nil -} - func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { @@ -1118,7 +976,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact return fmt.Errorf("failed to get account settings: %w", err) } - am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings))) return nil } @@ -1214,232 +1072,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - var ( - account *types.Account - err error - ) - if am.experimentalNetworkMap(accountID) { - account = am.getAccountFromHolderOrInit(accountID) - } else { - account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return - } - } - - globalStart := time.Now() - - hasPeersConnected := false - for _, peer := range account.Peers { - if am.peersUpdateManager.HasChannel(peer.ID) { - hasPeersConnected = true - break - } - - } - - if !hasPeersConnected { - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) - return - } - - var wg sync.WaitGroup - semaphore := make(chan struct{}, 10) - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - if am.experimentalNetworkMap(accountID) { - am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) - } - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) - return - } - - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - for _, peer := range account.Peers { - if !am.peersUpdateManager.HasChannel(peer.ID) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) - continue - } - - wg.Add(1) - semaphore <- struct{}{} - go func(p *nbpeer.Peer) { - defer wg.Done() - defer func() { <-semaphore }() - - start := time.Now() - - postureChecks, err := am.getPeerPostureChecks(account, p.ID) - if err != nil { - log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) - return - } - - am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) - start = time.Now() - - var remotePeerNetworkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountID) { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - } - - am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) - start = time.Now() - - proxyNetworkMap, ok := proxyNetworkMaps[p.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) - - peerGroups := account.GetPeerGroups(p.ID) - start = time.Now() - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) - am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) - }(peer) - } - - // - - wg.Wait() - if am.metrics != nil { - am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart)) - } -} - -type bufferUpdate struct { - mu sync.Mutex - next *time.Timer - update atomic.Bool + _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) - - bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) - b := bufUpd.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - am.UpdateAccountPeers(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - if b.next == nil { - b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() { - am.UpdateAccountPeers(ctx, accountID) - }) - return - } - b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load())) - }() + _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) } // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { - if !am.peersUpdateManager.HasChannel(peerId) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) - return - } - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err) - return - } - - peer := account.GetPeer(peerId) - if peer == nil { - log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId) - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) - return - } - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - postureChecks, err := am.getPeerPostureChecks(account, peerId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) - return - } - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - var remotePeerNetworkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountId) { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - - extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) - return - } - - peerGroups := account.GetPeerGroups(peerId) - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) + _ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. @@ -1594,14 +1237,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err != nil { return nil, err } - dnsDomain := am.GetDNSDomain(settings) - - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, err - } - - dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peer := range peers { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { @@ -1635,24 +1271,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } - - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ - Update: &proto.SyncResponse{ - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - NetworkMap: &proto.NetworkMap{ - Serial: network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - DNSConfig: &proto.DNSConfig{ - ForwarderPort: dnsFwdPort, - }, - }, - }, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) }) @@ -1661,14 +1279,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return peerDeletedEvents, nil } -func ConvertSliceToMap(existingLabels []string) map[string]struct{} { - labelMap := make(map[string]struct{}, len(existingLabels)) - for _, label := range existingLabels { - labelMap[label] = struct{}{} - } - return labelMap -} - // validatePeerDelete checks if the peer can be deleted. func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error { linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index e151f5abb..95c609595 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,7 +13,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "testing" "time" @@ -25,10 +24,14 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" @@ -172,12 +175,12 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { } func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testGetNetworkMapGeneral(t) } func testGetNetworkMapGeneral(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -249,7 +252,7 @@ func testGetNetworkMapGeneral(t *testing.T) { func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { // TODO: disable until we start use policy again t.Skip() - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -426,7 +429,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } func TestAccountManager_GetPeerNetwork(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -487,7 +490,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { } func TestDefaultAccountManager_GetPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -674,7 +677,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -742,12 +745,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } } -func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) { +func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) { b.Helper() - manager, err := createManager(b) + manager, updateManager, err := createManager(b) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } accountID := "test_account" @@ -798,7 +801,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou ips := account.GetTakenIPs() peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } peerKey, _ := wgtypes.GeneratePrivateKey() @@ -904,10 +907,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou err = manager.Store.SaveAccount(context.Background(), account) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } - return manager, accountID, regularUser, nil + return manager, updateManager, accountID, regularUser, nil } func BenchmarkGetPeers(b *testing.B) { @@ -928,7 +931,7 @@ func BenchmarkGetPeers(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -968,7 +971,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -980,14 +983,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) - for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels - b.ResetTimer() start := time.Now() @@ -1013,7 +1012,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } func TestUpdateAccountPeers_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testUpdateAccountPeers(t) } @@ -1037,7 +1036,7 @@ func testUpdateAccountPeers(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) if err != nil { t.Fatalf("Failed to setup test account manager: %v", err) } @@ -1049,13 +1048,12 @@ func testUpdateAccountPeers(t *testing.T) { t.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + peerChannels := make(map[string]chan *network_map.UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels manager.UpdateAccountPeers(ctx, account.Id) for _, channel := range peerChannels { @@ -1097,7 +1095,7 @@ func TestToSyncResponse(t *testing.T) { DNSLabel: "peer1", SSHKey: "peer1-ssh-key", } - turnRelayToken := &Token{ + turnRelayToken := &grpc.Token{ Payload: "turn-user", Signature: "turn-pass", } @@ -1177,9 +1175,9 @@ func TestToSyncResponse(t *testing.T) { }, }, } - dnsCache := &DNSConfigCache{} + dnsCache := &cache.DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) + response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1289,7 +1287,12 @@ func Test_RegisterPeerByUser(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1369,7 +1372,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1517,7 +1525,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1566,7 +1579,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ -1592,7 +1605,12 @@ func Test_LoginPeer(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1725,7 +1743,7 @@ func Test_LoginPeer(t *testing.T) { } func TestPeerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) @@ -1782,13 +1800,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) { var peer5 *nbpeer.Peer var peer6 *nbpeer.Peer - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) @@ -1890,6 +1909,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) { }) t.Run("validator requires no update", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") + requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { return update, false, nil } @@ -2091,7 +2112,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { } func Test_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2188,7 +2209,7 @@ func Test_IsUniqueConstraintError(t *testing.T) { } func Test_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2276,136 +2297,8 @@ func Test_AddPeer(t *testing.T) { assert.Equal(t, uint64(totalPeers), account.Network.Serial) } -func TestBufferUpdateAccountPeers(t *testing.T) { - const ( - peersCount = 1000 - updateAccountInterval = 50 * time.Millisecond - ) - - var ( - deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 - uapLastRun, dpLastRun atomic.Int64 - - totalNewRuns, totalOldRuns int - ) - - uap := func(ctx context.Context, accountID string) { - updatePeersDeleted.Store(deletedPeers.Load()) - updatePeersRuns.Add(1) - uapLastRun.Store(time.Now().UnixMilli()) - time.Sleep(100 * time.Millisecond) - } - - t.Run("new approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) - b := mu.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - uap(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - b.next = time.AfterFunc(updateAccountInterval, func() { - uap(ctx, accountID) - }) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalNewRuns = int(updatePeersRuns.Load()) - }) - - t.Run("old approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) - b := mu.(*sync.Mutex) - - if !b.TryLock() { - return - } - - go func() { - time.Sleep(updateAccountInterval) - b.Unlock() - uap(ctx, accountID) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalOldRuns = int(updatePeersRuns.Load()) - }) - assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) - t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) -} - func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2442,7 +2335,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2476,7 +2369,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { } func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2541,7 +2434,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/policy.go b/management/server/policy.go index ff02d46aa..3e84c3d10 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -10,7 +10,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" @@ -77,9 +76,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -123,9 +119,6 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -258,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin return validIDs } - -// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - fwRule := &proto.FirewallRule{ - PolicyID: []byte(rule.PolicyID), - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, - } - - if shouldUsePortRange(fwRule) { - fwRule.PortInfo = rule.PortRange.ToProto() - } - - result[i] = fwRule - } - return result -} - -func shouldUsePortRange(rule *proto.FirewallRule) bool { - return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) -} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 97ebbcf5a..90fe8f036 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -1135,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { } func TestPolicyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -1164,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var policyWithGroupRulesNoPeers *types.Policy diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index f457b994b..ac8ea35de 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,19 +2,15 @@ package server import ( "context" - "errors" - "fmt" "slices" "github.com/rs/xid" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -80,9 +76,6 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -139,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) } -// getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { - peerPostureChecks := make(map[string]*posture.Checks) - - if len(account.PostureChecks) == 0 { - return nil, nil - } - - for _, policy := range account.Policies { - if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { - continue - } - - if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { - return nil, err - } - } - - return maps.Values(peerPostureChecks), nil -} - // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) @@ -214,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account return nil } -// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { - isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) - if err != nil { - return err - } - - if !isInGroup { - return nil - } - - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck := account.GetPostureChecks(sourcePostureCheckID) - if postureCheck == nil { - return errors.New("failed to add policy posture checks: posture checks not found") - } - peerPostureChecks[sourcePostureCheckID] = postureCheck - } - - return nil -} - -// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, sourceGroup := range rule.Sources { - group := account.GetGroup(sourceGroup) - if group == nil { - return false, fmt.Errorf("failed to check peer in policy source group: group not found") - } - - if slices.Contains(group.Peers, peerID) { - return true, nil - } - } - } - - return false, nil -} - // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 67760d55a..13152ed12 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -21,7 +21,7 @@ const ( ) func TestDefaultAccountManager_PostureCheck(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er } func TestPostureCheckAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) postureCheckA := &posture.Checks{ @@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy where destination has peers but source does not // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { - updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + updateManager.CloseChannel(context.Background(), peer2.ID) }) _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ @@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } func TestArePostureCheckChangesAffectPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestPostureChecksAccount(manager) diff --git a/management/server/route.go b/management/server/route.go index 05f7acf9e..2b4f11d05 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -16,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -192,9 +191,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -249,9 +245,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -295,9 +288,6 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -381,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID return groupsMap, nil } -func toProtocolRoute(route *route.Route) *proto.Route { - return &proto.Route{ - ID: string(route.ID), - NetID: string(route.NetID), - Network: route.Network.String(), - Domains: route.Domains.ToPunycodeList(), - NetworkType: int64(route.NetworkType), - Peer: route.Peer, - Metric: int64(route.Metric), - Masquerade: route.Masquerade, - KeepRoute: route.KeepRoute, - SkipAutoApply: route.SkipAutoApply, - } -} - -func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0, len(routes)) - for _, r := range routes { - protoRoutes = append(protoRoutes, toProtocolRoute(r)) - } - return protoRoutes -} - // getPlaceholderIP returns a placeholder IP address for the route if domains are used func getPlaceholderIP() netip.Prefix { // Using an IP from the documentation range to minimize impact in case older clients try to set a route return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { - result := make([]*proto.RouteFirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - result[i] = &proto.RouteFirewallRule{ - SourceRanges: rule.SourceRanges, - Action: getProtoAction(rule.Action), - Destination: rule.Destination, - Protocol: getProtoProtocol(rule.Protocol), - PortInfo: getProtoPortInfo(rule), - IsDynamic: rule.IsDynamic, - Domains: rule.Domains.ToPunycodeList(), - PolicyID: []byte(rule.PolicyID), - RouteID: string(rule.RouteID), - } - } - - return result -} - -// getProtoDirection converts the direction to proto.RuleDirection. -func getProtoDirection(direction int) proto.RuleDirection { - if direction == types.FirewallRuleDirectionOUT { - return proto.RuleDirection_OUT - } - return proto.RuleDirection_IN -} - -// getProtoAction converts the action to proto.RuleAction. -func getProtoAction(action string) proto.RuleAction { - if action == string(types.PolicyTrafficActionDrop) { - return proto.RuleAction_DROP - } - return proto.RuleAction_ACCEPT -} - -// getProtoProtocol converts the protocol to proto.RuleProtocol. -func getProtoProtocol(protocol string) proto.RuleProtocol { - switch types.PolicyRuleProtocolType(protocol) { - case types.PolicyRuleProtocolALL: - return proto.RuleProtocol_ALL - case types.PolicyRuleProtocolTCP: - return proto.RuleProtocol_TCP - case types.PolicyRuleProtocolUDP: - return proto.RuleProtocol_UDP - case types.PolicyRuleProtocolICMP: - return proto.RuleProtocol_ICMP - default: - return proto.RuleProtocol_UNKNOWN - } -} - -// getProtoPortInfo converts the port info to proto.PortInfo. -func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { - var portInfo proto.PortInfo - if rule.Port != 0 { - portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} - } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { - portInfo.PortSelection = &proto.PortInfo_Range_{ - Range: &proto.PortInfo_Range{ - Start: uint32(portRange.Start), - End: uint32(portRange.End), - }, - } - } - return &portInfo -} - // areRouteChangesAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers. func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { diff --git a/management/server/route_test.go b/management/server/route_test.go index 388db140c..27fe033c8 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -432,7 +434,7 @@ func TestCreateRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -922,7 +924,7 @@ func TestSaveRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1024,7 +1026,7 @@ func TestDeleteRoute(t *testing.T) { Enabled: true, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1071,7 +1073,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1163,7 +1165,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1250,11 +1252,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } -func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { +func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createRouterStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} @@ -1285,7 +1287,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + if err != nil { + return nil, nil, err + } + return am, updateManager, nil } func createRouterStore(t *testing.T) (store.Store, error) { @@ -1948,7 +1959,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi } func TestRouteAccountPeersUpdate(t *testing.T) { - manager, err := createRouterManager(t) + manager, updateManager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestRouteAccount(t, manager) @@ -1976,9 +1987,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) { require.NoError(t, err, "failed to create group %s", group.Name) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + updateManager.CloseChannel(context.Background(), peer1ID) }) // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index e55b33c94..bc361bbd7 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -18,7 +18,7 @@ import ( ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } func TestGetSetupKeys(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) { } func TestSetupKeyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var setupKey *types.SetupKey @@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/user.go b/management/server/user.go index 66bea314f..be4e491a8 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -965,7 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if err != nil { return err } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) var peerIDs []string for _, peer := range peers { @@ -992,16 +992,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } + am.networkMapController.OnPeerUpdated(accountID, peer) } if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) - am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.DisconnectPeers(ctx, peerIDs) } return nil } @@ -1115,6 +1112,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var addPeerRemovedEvents []func() var updateAccountPeers bool + var userPeers []*nbpeer.Peer var targetUser *types.User var err error @@ -1124,7 +1122,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return fmt.Errorf("failed to get user to delete: %w", err) } - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) + userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) if err != nil { return fmt.Errorf("failed to get user peers: %w", err) } @@ -1147,6 +1145,17 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return false, err } + for _, peer := range userPeers { + err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) + } + + if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err) + } + } + for _, addPeerRemovedEvent := range addPeerRemovedEvents { addPeerRemovedEvent() } diff --git a/management/server/user_test.go b/management/server/user_test.go index 5920a2a33..69b8c85ee 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1161,7 +1161,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { } func TestDefaultAccountManager_SaveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1333,7 +1333,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1357,9 +1357,9 @@ func TestUserAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Creating a new regular user should not update account peers and not send peer update @@ -1468,9 +1468,9 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) - peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) + peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + updateManager.CloseChannel(context.Background(), peer4.ID) }) // deleting user with linked peers should update account peers and send peer update @@ -1748,7 +1748,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { } func TestApproveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1807,7 +1807,7 @@ func TestApproveUser(t *testing.T) { } func TestRejectUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d4a9f1823..d3f341529 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,9 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" @@ -68,7 +71,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctrl := gomock.NewController(t) @@ -111,15 +113,19 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } groupsManager := groups.NewManagerMock() - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } From 6fb568728fa9d8806e662703f1597fe0e14ce2b6 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 13 Nov 2025 12:51:03 +0100 Subject: [PATCH 053/118] [management] Removed policy posture checks on original peer (#4779) Co-authored-by: crn4 --- management/server/types/networkmapbuilder.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 58f1bfa30..41aaa7fc8 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -257,8 +257,6 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, ) ([]*nbpeer.Peer, []*FirewallRule) { - ctx := context.Background() - peerID := peer.ID peerGroups := b.cache.peerToGroups[peerID] @@ -275,9 +273,6 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n for _, group := range peerGroups { policies := b.cache.groupToPolicies[group] for _, policy := range policies { - if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { - continue - } rules := b.cache.policyToRules[policy.ID] for _, rule := range rules { var sourcePeers, destinationPeers []*nbpeer.Peer From 27957036c9130caeb1cc490a8a994eb1fcad01dc Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 13 Nov 2025 13:24:51 +0100 Subject: [PATCH 054/118] [client] Fix shutdown blocking on stuck ICE agent close (#4780) --- client/internal/peer/ice/agent.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index e80c98884..7b929c29d 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -22,6 +22,8 @@ const ( iceFailedTimeoutDefault = 6 * time.Second // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package iceRelayAcceptanceMinWaitDefault = 2 * time.Second + // iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete + iceAgentCloseTimeout = 3 * time.Second ) type ThreadSafeAgent struct { @@ -32,7 +34,17 @@ type ThreadSafeAgent struct { func (a *ThreadSafeAgent) Close() error { var err error a.once.Do(func() { - err = a.Agent.Close() + done := make(chan error, 1) + go func() { + done <- a.Agent.Close() + }() + + select { + case err = <-done: + case <-time.After(iceAgentCloseTimeout): + log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) + err = nil + } }) return err } From 3176b53968f326beb7ddd21b76eb45bfdf66cb28 Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Thu, 13 Nov 2025 10:25:19 -0300 Subject: [PATCH 055/118] [client] Add quick actions window (#4717) * Open quick settings window if netbird-ui is already running * [client-ui] fix connection status comparison * [client-ui] modularize quick actions code * [client-ui] add netbird-disconnected logo * [client-ui] change quickactions UI It now displays the NetBird logo and a single button with a round icon * [client-ui] add hint message to quick actions screen This also updates fyne to v2.7.0 * [client-ui] remove unnecessary default clause * [client-ui] remove commented code * [client-ui] remove unused dependency * [client-ui] close quick actions on connection change * [client-ui] add function to get image from embed resources * [client] Return error when calling sendShowWindowSignal from Windows * [client-ui] Add commentary on empty OnTapped function for toggleConnectionButton * [client-ui] Fix tests * [client-ui] Add context to menuUpClick call * [client-ui] Pass serviceClient app as parameter To use its clipboard rather than the window's when showing the upload success dialog * [client-ui] Replace for select with for range chan * [client-ui] Replace settings change listener channel Settings now accept a function callback * [client-ui] Add missing iconAboutDisconnected to icons_windows.go * [client] Add quick actions signal handler for Windows with named events * [client] Run go mod tidy * [client] Remove line break * [client] Log unexpected status in separate function * [client-ui] Refactor quick actions window To address racing conditions, it also replaces usage of pause and resume channels with an atomic bool. * [client-ui] use derived context from ServiceClient * [client] Update signal_windows log message Also, format error when trying to set event on sendShowWindowSignal * go mod tidy * [client-ui] Add struct to pass fewer parameters to applyQuickActionsUiState function * [client] Add missing import --------- Co-authored-by: Viktor Liu --- client/ui/assets/netbird-disconnected.ico | Bin 0 -> 5056 bytes client/ui/assets/netbird-disconnected.png | Bin 0 -> 7537 bytes client/ui/client_ui.go | 70 +-- client/ui/debug.go | 8 +- client/ui/icons.go | 3 + client/ui/icons_windows.go | 3 + client/ui/quickactions.go | 349 +++++++++++++++ client/ui/quickactions_assets.go | 23 + client/ui/signal_unix.go | 76 ++++ client/ui/signal_windows.go | 171 +++++++ go.mod | 39 +- go.sum | 520 ++-------------------- 12 files changed, 734 insertions(+), 528 deletions(-) create mode 100644 client/ui/assets/netbird-disconnected.ico create mode 100644 client/ui/assets/netbird-disconnected.png create mode 100644 client/ui/quickactions.go create mode 100644 client/ui/quickactions_assets.go create mode 100644 client/ui/signal_unix.go create mode 100644 client/ui/signal_windows.go diff --git a/client/ui/assets/netbird-disconnected.ico b/client/ui/assets/netbird-disconnected.ico new file mode 100644 index 0000000000000000000000000000000000000000..812e9d283d096d823824fd66e6533359fb375b4a GIT binary patch literal 5056 zcmbtYi$7EU|39ci7|A8M&h;ZCN^PPP$}QB~%Kd&X z5v|!rNJVU@G}rMvKEFTVd-m94XRmY4>+*cP-p|+TeEE+o=GrUFHx3^L5htA5B=438qg#UQWPc{T#JKG-w#US$_RqN6XAC0Rg`8Vf zE84eI+8+kT?CY(*w6|8$v_58qt6?@yt`rRp4o+W0p+KAe`!VQB69F>lMHYEDU8EmB z7N@tePkbIZ{1AEdzEZ7BL{up;Ey@4@PDO64l+0{y6W^~71xI8AOk)V?GiS9KG!e-4 zwjjCp;E}0K3Bys`GmDPHa?G@xr-)fPk3C+p&GpsEgf&dW(Pj0RYnxJ7DEWCOedBaWq|c2bvwe+p6o9FF)4nvq zKCNMCvD4f>DK8AcQqF$g!Jl1wAen&vLg$#`0rKDmcRS+plHAfvew8*DY=&-Q;H)N! z4SG0urfykX;0$2Y6*6H4x(9;Ks%gaw*zCx-1Ft@;OnIe|`oPh#0CB;+ny;V~3-~IH}MFv~PgfWlpAHKd3FhEOs=&?6y z^0-#;tnI|q6urN_mT+&yl>D>p2u$}RJWgmfQsv3~#@#hj+(((pdoP;8@yS@CpHW>~ z`vH8ezSTEIE!*kgnIm_5p4^7a%+2*m32y~YpG*Iv_ZFA}Z5IsBrS;65dcS@G#l5JP z`jmP-gb8F4hjIl#-SDDQ6$f*-gYzvU%9 zE{yj>S@mQ=BSiyw1Mzw%3w2#vqqyV76CBh2BS8AyyLFyID?$^si8xve6t^Fr+bP!628V)ipPv2q0T+VR#CGb|F{A1Rhg^*|SA@Oz@>iB_**o% zpfBiko{N!-aLX`zq(pGi2$T3QJUpz0YNBvspB-^8bAf?K;{;q-MQJlxsI>6v|*J{;CB_i|a?cTp$9zi!hUZspw%`F+PgVm}^< z8h-a~@8c&=M!LH8B)UmQTw8u+o!()a6Pv+Ll@Ppu)6|Sw{tK!ow>RY|yYUiP?hz52 zc^+FVTZy)>7pDqX1epBN)>h==LcwsdxrIgFY81mXkxF7F)^P7x55ZWW4V(Dn3yc=Q z-l0{l?X_rhQt+)?#R$k2lz;Z0u@Q}Vrn8?=uKJGdk6`aM;CP(2x>jV`1w+;T%}VPTNxcY28R zZ{NzSc$o)C;)|ZDa>Z*mUc4T?AB%*I-V6z`>mNFZ|M3006e2-Nx&cJ-r(rx6UhJFQ zPH+ZePU^s(=0}JpNH;C3mt}ZCIpLajgSG{$V56ucw;#0x7hUAklh~_acMN7OJP2_7 z+@Y&vsduVixOfD*KO`1^#P$~Fr0n#tYkb?(Yb1}2r=%>6P2qD+d}}+xF_qo=HTBEV zL{K=h_~n1*;9YHLXrgs`b~_Y6y+HJR+tdQxXAC?Y_SU0NC^X8m)aWsb*hjKb0CIvC zGPu9k#V)Z$HFB~jHC2EyVCnJz1agE2K=G#7me02n#MIT*;XFZbtgOP;JK08f*l|H& z83H=Vr0%P|y?yb7|A0%9PoeO(=YvO(i`WCwc`05yPQ9>2DK52G5=zBsXs}ee>uK86 z*WBy4G@aDt^pz@YD}B*c+dAe%m+iB!KZ8B>QkaIE-}UAmw87lJBJfqf{l%+fs1F z8SMk~;5<%0M>(dZv*89}wWOBJrld`26p*ZFCtWVv9oeUN3L$2a4brtTgbcSyJXCUH zkRqsyq`paXECHp~5b~TF-0`8^cn31XbDp8SE_(9*ZI*|y@w|VFewF~BN0kje{&8r# z>a>D_F&rBqs=sV2zxB>q89=&~THcc(wWwa?4^2~&0y!>g@9yoY-#44BK(RqV=3*UU zwXUvCd%&er+=LGhAKZ@_S&iQD_!$|W(Vy7j)XxXxhx5fQEcng{#O%wyc*1#6EK3sv z@)QqXr5Ue-PI5j5=-1cRkCbEP^sQEOIpQNDC&04V6FV*_ti)Y>qnd!Mzru5LCA{ew zOTX~>_d_pnX@^ui>If(_vHI;>oRZFO^~fCNy@4~j!?B7H*NT@tmc632F0Qb5pZ z{{xL$dk3A^@{iex19kyI8^tq>Jfn}|uJ>dLy8ffK;!DgjYlN1YQ3Kf1%T>rSHoPfm z1`>p1V5;NM^pc*oJ4u=~*nZHV-2Oj@-spUvdpf{0NFkd%T z4{Gz{RjjG1`lCBW>css!SwZV_7cW2;f19*+=?1fS#*06#hv+fxge!Y<=mWipO|eY{ zbP8RNA8!5oz@&R7&|}HNWIDi}!~+apL9R(h9F^WW2u=wmDY1As5gev`4g$%^^%IND zgC+PEX^r;Adz{F0Z&j_{?Oc!ID_|strNyP-u62Z8veca7cNs%3eMG@0@}< zuoFZ_rvsL;JgdCukd_^dUcfU;$qw$dGcWu`y>)Fg1_8l5>n*{1Z`uVgO4`z1AGing4;(AB5Xw1;Kde~O(a~`;?R8yhvko-PSd?*_hSit?2G#sW&uFd07<^^WDBRyh)jN!&zVxV5{%R;b@cf00Pe4zjbwc=wp*%8lJR>jfSPCrvZ|{tfw> zJK%yCq&mRe(`0rAdrL8Zd|SxG74h)?b0=YtzZytjgUMtPJSxo&!U5B%+?L1dlW?At zw+g^0SYHW1#z%1UGrirXn;Ya)$LvI#$jK(>&nfIw41nz!3Ea4TrA{K@bJNlpX6-d1 z5vr>Lxh>e5nm-rMI#a?6!TfXf?~=w!J&9S}+_?lb-q$A8q+pftLgk7Io zy4L(@HLAX`F{|mD^XBwF$CE#BB9O}DH2HxIi}6a5TF%PE&B;EO8|8^^M(HyenN31= zAild(9MMm1U{y88%%{egMsc~+yhQgObT|YDL z*hb5d`S({hiy#2>LMf3Mx-(Yh@PrjoEGVjj2xKlz(6srobwbd@FRPO=cdK@I;DF3G zGw#wypBiru4}FfI^0C=QP6YR(&f)8^Gcz;gih?Z%xj_o2Pb+G@l)Pim5y9O!GafC2 zfjVLYJ4Q(@^5bkZhMI|@hqtPj5!VI}P$j;w=`stm0e*fMxbPtD<+ei0)(E9Z=n4kV zz%Ga@`5UB`J6SuTm0bd$*rvtz$O*Zlz)4nq{X6Gki?s)0ag?R#boR19Z@1fCO z!#(-8`T}zg{JTU*8|MQUAe=~;f52m*IDpg=HySQCP9rwEgL#%rcnWqhhG}MdVD9z; zy;760pr}t$0q7y;t%TVFE{4Z^eSK?kWjoP-j%kJ``0ntGsX>!%cFY(Hy=4>r2T)@^ zrO58cgOv|SI%CY}Yy8+b7Aq|dlhKql68OHUsY$b%i?AOs1>UcVFnp!kgP$55eSs*i zs1X0Fpv&gbVGx^3?f)skrzxea%Nm#SkQeH6a5dZ!zlL#8X8Ss&1pDY&O zCSD){-(OSCe#&-06+p=;yNIdfjQdH{$}fO0&d=S7rN+suZi-w}6&vs0=RPNrwo?a`L22A%*d57c;21=15}9?ktrkLheG# zynCP+AlDK?7Zw(*`BGd5DVMv2A;X*ji%;H41GGji7hJ%S$|s<&(}d4XnxkDDC;+e%G{H+S_Y0e2hGl|Fxq)#s;ORGgch^=Q)zZB=5dqa6?Dm zZxZ?lP8_)N=jFb6fxBuhhZ*3#DWk^}#0Xxf=E_c;HFtyRpPhX(Nl}{st0p2~4nCUx*N7L{Z1yuY-~ayB3e=i+|Lo;vZtUta M);HIyz`Mr%AJn`~GXMYp literal 0 HcmV?d00001 diff --git a/client/ui/assets/netbird-disconnected.png b/client/ui/assets/netbird-disconnected.png new file mode 100644 index 0000000000000000000000000000000000000000..79d4775eab038831d38abcc2845f671adc7a1b4c GIT binary patch literal 7537 zcmdsci9b~T_y23gsIe6yYs|=&!VoEmL6+>KvcwRw@9R6ujJ1e}A|<=X63Uu2(nhu{ zF@)@dDU7kq%y0C5e?IT`@BR3G|AFs)-1~T)d(L^D<=%78oO>tM$UuvQk(Utwz;a3Z zq6q*H8VLdL1GMG3Pti@<0&`T=R|TLlk!jDCo_3C~(>Bou;Fc%=xF`U2Xe`_U0RAWd zmTUk}d;|cOSH^2&C7K}0!R*paeSIK7qv3!a!Vc(Y6lC8Egcs~v+xMpf;rlOb0uld9 z2MWM#X8`+4$AY%*Z`!nDAM@u*mks&bVm9-1M`@XJ=gvu^WTjF;+zNNC^Df44v5^mD%XwW{?@z}-T=}+~rZAx}5xmxf%{Aqp;p3edADySMS^}Cak1ATi##q1N@zcllPBW9J zZ>V!m6g;1qWnq{jvCS_YUWMJng#6SHg7n5cHSkfzMG8cQF#SR0Yti4UszWACuC{cR zihwb2^#u!qQ?CNdL19E{jnZXu>w#6pr!r)?R-%CW_j8m9lKY~0Gg9vS$K>+noh6IZ zhjO_1C`vFg>Cj#SH|a@ZcIIsG8|4;wkY|rVg%BjEFvs5!7DqQR68*sYSolZ+pTePa z24V0^uFB&n?m)ADe%a|RX6!cepR5EBAMA%GgDce2uIUnaL$KaMzU>(0l!tpULmI6I zAY+@J<)x)cs@A-FXK(>M$G4ZuTUOH?jcD{W}@y?gft*|jD2g0=Vz~3Of&|gzWO}Wp2?`_1Sf8Q5&6Q z9O#^MT=Ey?NbN49_4%F&nG@bgiQvZrE@|uQjv6u@-cMUf`ya)g`8+}%O2g=H<#mdZJLg5P^r=GQx)9JANw5UN1jDe#eO{M z4s_`Vgq|*<;N6+dae-~=@@zT}mx@{qm$gPc1mZn)BM13m!LfX0=MzIri|ezPmXML| zf_IscKQiVHJB{l?8Fc*NQ;nXM=Ud0IOlmP#v6A7ett>aTXtmosTu?WkPxODIFjoIO zfWX~0sp6hRec4}f*j_kd$8(mJDPe{UcMSAo81|+Q-DRj+Q3=mLITi3|Z*H|D=<0m^ zv~sP>(SuHVy54&zEed_##0;Nt7u7CYKo;$ZSsce)rM8ZqrTS+{X=Jr9p}9B|Kl^{v ze{Emndr+)3J!Oj^KVDe(iW&uJzH&`SY3s{Uxd;O4^y;1T?x~WY>vSL3bHBLRTD{9= z*5;=&jMX;;VZWowIgQSKAaGr$3-sCLbUi-|UeYc4wXL{WQjJ+BXEHD7P2Oso4}BeC ze)2~HpT>%&X{@Mivz}K*A9$J4XB<$`y&!ENBuw}<+o|pZd%_B38%oiUaok{{BAOzZ_iX!iObxzSm~1VYKGDj-!=j z#CVLBX11I!gh0!!n-qEWN_hhzb(=ehvSt?mZD5syPxD;=*5{SKcK7+k*8+9Hr zk1Mw*Vj&cE2ts??5>X?|7%3f7ngW-eeE7kU8cMqe;m63U(YDLyvvFwC{82x}uY!=e z`LcR)T!Du$2x5JLv*Yi52~B8T{Zar+h{ts>*&G=8s3l~R@l{;lOAQuw)@zv=D=LzD zzc=b?VU+=`39<9If-an@RE|kM>b{rd-3gdmF`@s!EmzmVV(itsS81Nq;lg#VuD^F_ zbC(QbpY&Alu-&S%)G9>uLQhq8o}! zH-_6tZO@DEc`N&t9oE9p^$)BJ*Hn>6v2)B-(X~O=5s3?9n6cS_&NyJh#%c7{T|IrS z`@C;fpu-5go@J37*+bb{-I=lTyk$LP`4djO{Yfv3!{OPXnIwFkB@%plr|Yrn79N*!7&pzRQdYl-|o_OqXx|JRrB% zx2dq920=@Qr)t7(*0ikmE);9eTm*#g9MB++ukIGC1+UbzrQ0zn{U+Qp-MG|S(9qZ9 zI1z7A5QL=XIPpE8VZ3&#{F3`whD&PMk2yS(L-{Db3MxHdU5=S$C$qIR+s0K`u*xmX zoxOI~%B}qmBDyjT%IwV=pBn@(nJyjqpk24S^>a0=yi6`!DPenOGVw!`6EOGBVd%11 z*;B+1wH$-0H<0yQkX%z4Lz98n)O;=$M2Re2++@0=%5O39g5BlIQ=UJVn6bm5HoYal zCi+W}+HGZ1NJ$QGYU>VHA?;bLBT50YvqhyLo@nD}J&R(YGDDA+IHAkcOQsxV+Ja0AU2cE?n$HNTqCIaFbfGp?Z`wnlKa7EXdI<>Nhb z*ed!;*eZg&qmDlLE~(b6fRYsW5^y0ri5crJuliy6vg~uzO74CVH)^ASX>P+Zs}jGXrzLhs0+@+0K-N=dvn!9ybvAMkyR%hMrk0wQ z8GF41wm9cXoYNn9ctKj#wBbo=>mtInm@gL5Zaq#WT1rb}&dD#<SHxpYN-k; zYpEfWME8&J5tV}{%MuklT7)p3Fu?SzNUdYed~B{Y(frC>o_pLxtEtb=DUX zvjBn&LS_rZC)?K#OVfk!6UDtZYGXTAYRyW1m6=5yJ4LnTrTiLb;smGia~YdlC5k&_ z>dC*3Rk%EbN)w*9F|Tjd6^Awd*s=hu$u}7Ea?k7yvP30`!C83bO71!Te6$6{(F?p{ z%^p2gW}b2bSiHlNozJ%!_VnqH(B)vxt8Ua zt;M#p+>lnCXKFSry3y*23d9j_{n3a>m5Tbq4UYx_#_~Os^BI0|mdZpOogT-2DBmUZ zo_i%r(0{YJxyf)d?q)*>d25xHcyyauylXm3hJ$0_=R<1%-3sW zaNbYe;vOwG?C4JdGd4!Z7-ud!RCbTEVPWyictobbp<9!eAC+Ie(K#F;2R&NCsp}+D zT=)L{`<(iwxX)95_;zV#dPh>Hf>AFo`YM|f^b+%9utt#!a(*s4nImaa`?`uE7^?pDbRRm-Jw6pm=QqXBH4FVDWc<>a+1|Ji27~%W#V}el7}Rtf=#8#_R8(U zf~z*WxA&AWL(hq>&Dhm>&z5%peEZDdu6c~WlP5irAmtdM`e+KUe(2FuAyyw^pk zD(jz%Hvz2dqLW5P&dSmci8Rga2IgD~l^X4vwFBc}ua-(~huT*`ah0&0dlJI!Sq*Q# zFUjCAr2*HLrhEm;0AfcEZX-Wc(;4kjTw<{Fax7E z{aBV#yZiIS+d#7W#M5`*t~zYil049+>*P)|huwULy)Yrm8Ja|SUqAlzs{htnRL_F_ z-L(T3 zXw`6?eZl1*j*vLIa?M^EVkDgG5tu+LBFu1uu|{82Lk#-;7Bm;WekF3g3(0^xt!)sP zUpz@%%$#%e*_C&6I4u$_<#?dRESL6xnFKQ6M5;`XUXGVuMz_oENlq?=<|P~jIhBG- z!=bN1DqjWiLP+wS0Q1GbG?Bt7L)+$5W1rx?{sP^U!^))A8KFGyUCtbBI&t>rUp;t@ zq5`p!?d(pr!8{50@~^V&Kto6#-x&qYXO4x*gJ4B_agfWj?hs)PZH+qG|4c1gt~hBd zECzK!y*v8z+2YpcBd6vRlY@H=-@YkH9yu`uDOo_;jA$T!!S&nXP_oH+Y$OnOdB`MjN{JdmnIBgHVl14Z&%d|vf4|CaMB6)d0ye$v%=?Ktx9u|toc6&0)1PCgm0^PcYQ6N&7+ z3e8qev7PL3JfH*j1B(j-vJg~2ErzWchQ6=ezCwHX^(Gjue>?$uK^riq)(cao7!XVA zYDj_^qJ08-|KmkyHpu@ds(w4?wEM$?j@=ottkP6g@w_i<%@GHZ(1@|MS{MbR56#x# zD=4N`EtxVxMqyRVy4eU2z_#+yqNvu^_2$}A{}mSVXW{BuXU-+OqQf)>^$9=kKPyD| zD3*nCtS3v5bm)q}WSd|jy(E{Sf>knLhn zxA<`c2a=X1SW94N7V+0)uON$*{4Y7x)$$DC2C$yY_&<1Gc{Xp~3?SYiK441%;&~CF zi}5Yn_%P0oiYs%wP!6cyHRnYG9RGHk;a0}Glt>s_!_v;{+q)d+4qMlT6EF_l)R|tPC2z_zeuC(#S*SQzzob5`q}DD7YyDPkSSuj9})%P#}SaEyohZ|6F0>t?UdE*9=LVT2{B+|2m-O) zTPnqd1$kFs!oWmUZtqR>z>}%+V%3*q4rokPUAMup&nnd-do#r4-tX+-M5F0uR)CE& z*_58+xfY}0=he?o6jF9f6)m*TilPhoDPg?P+r^QKOWZBzNxx=wuZqcd(x=NZ`h5 zJzm3p%Ih3mphxfgn25IZoY#Tt87Lhm{VE>|OT9H-G@-S^JJ)H#$L*kw1fCsH0pT)I zb-O);7g*VMuN#Wio#X*8g;2olITj|Gpu086oyXWmxNR7=JTj60>eQw47A%M>%AO79{F%OQ&KHMM}ZfdUmcj*3)Zk z!*SgkDGv7{?cYCdvx}90T`Mxh9&B>+gChKQ^^aBX3wP?l(M2piWM!#1MTxGoe9GcY&;)ObcGCRw^x%X_M1+QIGrz(OjwXAa<6Wd( z=0L2NAjvFXCFkB3d2KGFRWwqHaedIRBtea?6+y}o%%EHs*(g%$wUL`>xxLEK1nzuZ zaqJgUOOVK?y=#_@NxBa}aWO5GO1szmBYSfh`(l31edent7xNDMq7>1i6-kAu>GWyG zqQQg^v#sqT>Co}QmejZ*_@i&+EaL5c>dB^Cn=WlxIH7fT?Mt161^6E9y( z2cy}XBuvO@^0~9q#x6i^PZj(yzsac|pzK2X^0PS}I6*$i`yt#^Sce4>N<5s;up=!h z4>`cQ>1_O{+}7gwYa9X@v(1~k$a4kGcE1rWu2!4X7Q)iUB$`qrN4w>yMMqce2UAuL zjia+~9k%e|rEtu-)`{NH0UuTzzuhMtK#DQ5RA+ytam7h>9&(!Kdh+O&P_s1ZqLU11 zFXider>-lyFHSeG-_fK)JMs>8DqHa>W0fcPXQsjGyY2DrVcYlz_?~mKJ4_dmV1cME z_N}PA_}&oP7`k6E`$%XgzZgB6z2rl5a8$LZi zmuk$Ba~KJ>rs+O!ZRIts)Mr_p42~-lN>_( z)H|$m*kv98ix>nb(kyW~r_<}%v_l1qmHUSxwN^JH=?7<7JmTUE~xm$+F7#7mlaAIe#%+Ib5WQ`xx0NDk$u2# zOTlH>n%%c%a%stBwwhzEggrswU<3+!qnYa7_)2**<}stJ7#i8ydm{t&0K+kN(y)EP ze{H`Xph$sx8b!&R`@Nvhk&4=h-h$HFK%wmQB z5$q0ofR`@qj#?Z^aqP(y*>7kiGZJO+JGlQ-dj=~Bdt4U2cyIH?+x Date: Thu, 13 Nov 2025 20:16:45 +0100 Subject: [PATCH 056/118] [client] Use stdnet with a context to avoid DNS deadlocks (#4781) --- client/iface/iface_test.go | 21 ++-- client/iface/udpmux/mux.go | 6 +- client/internal/dns/server_test.go | 6 +- client/internal/engine_stdnet.go | 2 +- client/internal/engine_stdnet_android.go | 2 +- client/internal/engine_test.go | 6 +- client/internal/peer/guard/ice_monitor.go | 2 +- client/internal/peer/ice/agent.go | 5 +- client/internal/peer/ice/stdnet.go | 6 +- client/internal/peer/ice/stdnet_android.go | 10 +- client/internal/peer/worker_ice.go | 2 +- client/internal/relay/relay.go | 4 +- client/internal/routemanager/manager_test.go | 4 +- .../systemops/systemops_generic_test.go | 4 +- client/internal/stdnet/stdnet.go | 112 +++++++++++++++++- 15 files changed, 153 insertions(+), 39 deletions(-) diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index e890b30f3..6bbfeaa63 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -1,6 +1,7 @@ package iface import ( + "context" "fmt" "net" "net/netip" @@ -9,13 +10,13 @@ import ( "time" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/stdnet" ) // keep darwin compatibility @@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) addr := "100.64.0.1/8" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) { func Test_CreateInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) wgIP := "10.99.99.1/32" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -166,7 +167,7 @@ func Test_Close(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) wgIP := "10.99.99.5/30" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) { func Test_UpdatePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.9/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.13/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) { peer2wgPort := 33200 keepAlive := 1 * time.Second - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) { guid = fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - newNet, err = stdnet.NewNet() + newNet, err = stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/iface/udpmux/mux.go b/client/iface/udpmux/mux.go index 319724926..c5d2de4a5 100644 --- a/client/iface/udpmux/mux.go +++ b/client/iface/udpmux/mux.go @@ -1,6 +1,7 @@ package udpmux import ( + "context" "fmt" "io" "net" @@ -12,8 +13,9 @@ import ( "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" ) /* @@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() { if len(networks) > 0 { if m.params.Net == nil { var err error - if m.params.Net, err = stdnet.NewNet(); err != nil { + if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil { m.params.Logger.Errorf("failed to get create network: %v", err) } } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 451b83f92..d12070128 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { privKey, _ := wgtypes.GenerateKey() - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Errorf("create stdnet: %v", err) return @@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Fatalf("create stdnet: %v", err) return nil, err diff --git a/client/internal/engine_stdnet.go b/client/internal/engine_stdnet.go index 9e171b0b2..1ebb5779c 100644 --- a/client/internal/engine_stdnet.go +++ b/client/internal/engine_stdnet.go @@ -7,5 +7,5 @@ import ( ) func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(e.config.IFaceBlackList) + return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList) } diff --git a/client/internal/engine_stdnet_android.go b/client/internal/engine_stdnet_android.go index 68a0ae719..de3c80bcf 100644 --- a/client/internal/engine_stdnet_android.go +++ b/client/internal/engine_stdnet_android.go @@ -3,5 +3,5 @@ package internal import "github.com/netbirdio/netbird/client/internal/stdnet" func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) + return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 15ac0a947..d15a07f9d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -14,7 +14,7 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { MTU: iface.DefaultMTU, }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -977,7 +977,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 0f22ee7b0..a201dd095 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { log.Debugf("Gathering ICE candidates") - agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) + agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) if err != nil { return false, fmt.Errorf("create ICE agent: %w", err) } diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 7b929c29d..79f68d279 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,6 +1,7 @@ package ice import ( + "context" "sync" "time" @@ -49,13 +50,13 @@ func (a *ThreadSafeAgent) Close() error { return err } -func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { +func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() iceFailedTimeout := iceFailedTimeout() iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList) if err != nil { log.Errorf("failed to create pion's stdnet: %s", err) } diff --git a/client/internal/peer/ice/stdnet.go b/client/internal/peer/ice/stdnet.go index 3ce83727e..685ed0363 100644 --- a/client/internal/peer/ice/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -3,9 +3,11 @@ package ice import ( + "context" + "github.com/netbirdio/netbird/client/internal/stdnet" ) -func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNet(ifaceBlacklist) +func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ctx, ifaceBlacklist) } diff --git a/client/internal/peer/ice/stdnet_android.go b/client/internal/peer/ice/stdnet_android.go index 84c665e6f..5033ec1b9 100644 --- a/client/internal/peer/ice/stdnet_android.go +++ b/client/internal/peer/ice/stdnet_android.go @@ -1,7 +1,11 @@ package ice -import "github.com/netbirdio/netbird/client/internal/stdnet" +import ( + "context" -func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 5d8ebfe45..840fc9241 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -209,7 +209,7 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { - agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 693ea1f31..59be5b0a7 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return @@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index d2f02526c..3697545ae 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -6,7 +6,7 @@ import ( "net/netip" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/stretchr/testify/require" @@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index d9b109beb..01916fbe3 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -15,7 +15,7 @@ import ( "syscall" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen peerPrivateKey, err := wgtypes.GeneratePrivateKey() require.NoError(t, err) - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) require.NoError(t, err) opts := iface.WGIFaceOpts{ diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 4b031c05c..381886ac6 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -4,17 +4,28 @@ package stdnet import ( + "context" + "errors" "fmt" + "net" + "net/netip" "slices" + "strconv" "sync" "time" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + + "github.com/netbirdio/netbird/client/iface/netstack" ) -const updateInterval = 30 * time.Second +const ( + updateInterval = 30 * time.Second + dnsResolveTimeout = 30 * time.Second +) + +var errNoSuitableAddress = errors.New("no suitable address found") // Net is an implementation of the net.Net interface // based on functions of the standard net package. @@ -28,12 +39,19 @@ type Net struct { // mu is shared between interfaces and lastUpdate mu sync.Mutex + + // ctx is the context for network operations that supports cancellation + ctx context.Context } // NewNetWithDiscover creates a new StdNet instance. -func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { +func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } // current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client // so in android cli use pionDiscover @@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri } // NewNet creates a new StdNet instance. -func NewNet(disallowList []string) (*Net, error) { +func NewNet(ctx context.Context, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ iFaceDiscover: pionDiscover{}, interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } return n, n.UpdateInterfaces() } +// resolveAddr performs DNS resolution with context support and timeout. +func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return netip.AddrPort{}, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err) + } + if port < 0 || port > 65535 { + return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port) + } + + ipNet := "ip" + switch network { + case "tcp4", "udp4": + ipNet = "ip4" + case "tcp6", "udp6": + ipNet = "ip6" + } + + if host == "" { + addr := netip.IPv4Unspecified() + if ipNet == "ip6" { + addr = netip.IPv6Unspecified() + } + return netip.AddrPortFrom(addr, uint16(port)), nil + } + + ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout) + defer cancel() + + addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host) + if err != nil { + return netip.AddrPort{}, err + } + + if len(addrs) == 0 { + return netip.AddrPort{}, errNoSuitableAddress + } + + return netip.AddrPortFrom(addrs[0], uint16(port)), nil +} + // UpdateInterfaces updates the internal list of network interfaces // and associated addresses filtering them by name. // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one @@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I } return result } + +// ResolveUDPAddr resolves UDP addresses with context support and timeout. +func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + switch network { + case "udp", "udp4", "udp6": + case "": + network = "udp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err} + } + + return net.UDPAddrFromAddrPort(addrPort), nil +} + +// ResolveTCPAddr resolves TCP addresses with context support and timeout. +func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + switch network { + case "tcp", "tcp4", "tcp6": + case "": + network = "tcp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err} + } + + return net.TCPAddrFromAddrPort(addrPort), nil +} From e4b41d0ad70676b3f3f4a18f621b5467bd1c509c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 14 Nov 2025 00:25:00 +0100 Subject: [PATCH 057/118] [client] Replace ipset lib (#4777) * Replace ipset lib * Update .github/workflows/check-license-dependencies.yml Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Ignore internal licenses * Ignore dependencies from AGPL code * Use exported errors * Use fixed version --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .../workflows/check-license-dependencies.yml | 73 +++++++++- client/firewall/iptables/acl_linux.go | 135 ++++++++++++++---- client/firewall/iptables/router_linux.go | 50 +++++-- go.mod | 6 +- go.sum | 12 +- 5 files changed, 225 insertions(+), 51 deletions(-) diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index d3da427b0..2a3e7d424 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -3,10 +3,19 @@ name: Check License Dependencies on: push: branches: [ main ] + paths: + - 'go.mod' + - 'go.sum' + - '.github/workflows/check-license-dependencies.yml' pull_request: + paths: + - 'go.mod' + - 'go.sum' + - '.github/workflows/check-license-dependencies.yml' jobs: - check-dependencies: + check-internal-dependencies: + name: Check Internal AGPL Dependencies runs-on: ubuntu-latest steps: @@ -33,9 +42,67 @@ jobs: if [ $FOUND_ISSUES -eq 1 ]; then echo "" echo "❌ Found dependencies on management/, signal/, or relay/ packages" - echo "These packages will change license and should not be imported by client or shared code" + echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" exit 1 else echo "" - echo "✅ All license dependencies are clean" + echo "✅ All internal license dependencies are clean" fi + + check-external-licenses: + name: Check External GPL/AGPL Licenses + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + cache: true + + - name: Install go-licenses + run: go install github.com/google/go-licenses@v1.6.0 + + - name: Check for GPL/AGPL licensed dependencies + run: | + echo "Checking for GPL/AGPL/LGPL licensed dependencies..." + echo "" + + # Check all Go packages for copyleft licenses, excluding internal netbird packages + COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true) + + if [ -n "$COPYLEFT_DEPS" ]; then + echo "Found copyleft licensed dependencies:" + echo "$COPYLEFT_DEPS" + echo "" + + # Filter out dependencies that are only pulled in by internal AGPL packages + INCOMPATIBLE="" + while IFS=',' read -r package url license; do + if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then + # Find ALL packages that import this GPL package using go list + IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath") + + # Check if any importer is NOT in management/signal/relay + BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1) + + if [ -n "$BSD_IMPORTER" ]; then + echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER" + INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n" + else + echo "✓ $package ($license) is only used by internal AGPL packages - OK" + fi + fi + done <<< "$COPYLEFT_DEPS" + + if [ -n "$INCOMPATIBLE" ]; then + echo "" + echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:" + echo -e "$INCOMPATIBLE" + exit 1 + fi + fi + + echo "✅ All external license dependencies are compatible with BSD-3-Clause" diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d78372c9e..5ccaf17ba 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -1,13 +1,14 @@ package iptables import ( + "errors" "fmt" "net" "slices" "github.com/coreos/go-iptables/iptables" "github.com/google/uuid" - "github.com/nadoo/ipset" + ipset "github.com/lrh3321/ipset-go" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -40,19 +41,13 @@ type aclManager struct { } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { - m := &aclManager{ + return &aclManager{ iptablesClient: iptablesClient, wgIface: wgIface, entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), - } - - if err := ipset.Init(); err != nil { - return nil, fmt.Errorf("init ipset: %w", err) - } - - return m, nil + }, nil } func (m *aclManager) init(stateManager *statemanager.Manager) error { @@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering( specs = append(specs, "-j", actionToStr(action)) if ipsetName != "" { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { - if err := ipset.Add(ipsetName, ip.String()); err != nil { - return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + if err := m.addToIPSet(ipsetName, ip); err != nil { + return nil, fmt.Errorf("add IP to ipset: %w", err) } // if ruleset already exists it means we already have the firewall rule // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. @@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering( }}, nil } - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %s before use it: %s", ipsetName, err) + if err := m.flushIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("flush ipset %s before use: %v", ipsetName, err) + } else { + log.Errorf("flush ipset %s before use: %v", ipsetName, err) + } } - if err := ipset.Create(ipsetName); err != nil { - return nil, fmt.Errorf("failed to create ipset: %w", err) + if err := m.createIPSet(ipsetName); err != nil { + return nil, fmt.Errorf("create ipset: %w", err) } - if err := ipset.Add(ipsetName, ip.String()); err != nil { - return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + if err := m.addToIPSet(ipsetName, ip); err != nil { + return nil, fmt.Errorf("add IP to ipset: %w", err) } ipList := newIpList(ip.String()) @@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { return fmt.Errorf("invalid rule type") } + shouldDestroyIpset := false if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { // delete IP from ruleset IPs list and ipset if _, ok := ipsetList.ips[r.ip]; ok { - if err := ipset.Del(r.ipsetName, r.ip); err != nil { - return fmt.Errorf("failed to delete ip from ipset: %w", err) + ip := net.ParseIP(r.ip) + if ip == nil { + return fmt.Errorf("parse IP %s", r.ip) + } + if err := m.delFromIPSet(r.ipsetName, ip); err != nil { + return fmt.Errorf("delete ip from ipset: %w", err) } delete(ipsetList.ips, r.ip) } @@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { // we delete last IP from the set, that means we need to delete // set itself and associated firewall rule too m.ipsetStore.deleteIpset(r.ipsetName) - - if err := ipset.Destroy(r.ipsetName); err != nil { - log.Errorf("delete empty ipset: %v", err) - } + shouldDestroyIpset = true } if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { @@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { } } + if shouldDestroyIpset { + if err := m.destroyIPSet(r.ipsetName); err != nil { + if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("destroy empty ipset: %v", err) + } else { + log.Errorf("destroy empty ipset: %v", err) + } + } + } + m.updateState() return nil @@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error { } for _, ipsetName := range m.ipsetStore.ipsetNames() { - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + if err := m.flushIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("flush ipset %q during reset: %v", ipsetName, err) + } else { + log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + } } - if err := ipset.Destroy(ipsetName); err != nil { - log.Errorf("delete ipset %q during reset: %v", ipsetName, err) + if err := m.destroyIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("destroy ipset %q during reset: %v", ipsetName, err) + } else { + log.Errorf("destroy ipset %q during reset: %v", ipsetName, err) + } } m.ipsetStore.deleteIpset(ipsetName) } @@ -368,8 +387,8 @@ func (m *aclManager) updateState() { // filterRuleSpecs returns the specs of a filtering rule func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { matchByIP := true - // don't use IP matching if IP is ip 0.0.0.0 - if ip.String() == "0.0.0.0" { + // don't use IP matching if IP is 0.0.0.0 + if ip.IsUnspecified() { matchByIP = false } @@ -416,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return ipsetName + actionSuffix } } + +func (m *aclManager) createIPSet(name string) error { + opts := ipset.CreateOptions{ + Replace: true, + } + + if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { + return fmt.Errorf("create ipset %s: %w", name, err) + } + + log.Debugf("created ipset %s with type hash:net", name) + return nil +} + +func (m *aclManager) addToIPSet(name string, ip net.IP) error { + cidr := uint8(32) + if ip.To4() == nil { + cidr = 128 + } + + entry := &ipset.Entry{ + IP: ip, + CIDR: cidr, + Replace: true, + } + + if err := ipset.Add(name, entry); err != nil { + return fmt.Errorf("add IP to ipset %s: %w", name, err) + } + + return nil +} + +func (m *aclManager) delFromIPSet(name string, ip net.IP) error { + cidr := uint8(32) + if ip.To4() == nil { + cidr = 128 + } + + entry := &ipset.Entry{ + IP: ip, + CIDR: cidr, + } + + if err := ipset.Del(name, entry); err != nil { + return fmt.Errorf("delete IP from ipset %s: %w", name, err) + } + + return nil +} + +func (m *aclManager) flushIPSet(name string) error { + return ipset.Flush(name) +} + +func (m *aclManager) destroyIPSet(name string) error { + return ipset.Destroy(name) +} diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 305b0bf28..1fe4c149f 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -10,7 +10,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/hashicorp/go-multierror" - "github.com/nadoo/ipset" + ipset "github.com/lrh3321/ipset-go" log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" @@ -107,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1 }, ) - if err := ipset.Init(); err != nil { - return nil, fmt.Errorf("init ipset: %w", err) - } - return r, nil } @@ -232,12 +228,12 @@ func (r *router) findSets(rule []string) []string { } func (r *router) createIpSet(setName string, sources []netip.Prefix) error { - if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { + if err := r.createIPSet(setName); err != nil { return fmt.Errorf("create set %s: %w", setName, err) } for _, prefix := range sources { - if err := ipset.AddPrefix(setName, prefix); err != nil { + if err := r.addPrefixToIPSet(setName, prefix); err != nil { return fmt.Errorf("add element to set %s: %w", setName, err) } } @@ -246,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error { } func (r *router) deleteIpSet(setName string) error { - if err := ipset.Destroy(setName); err != nil { + if err := r.destroyIPSet(setName); err != nil { return fmt.Errorf("destroy set %s: %w", setName, err) } @@ -915,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) continue } - if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err)) + if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { @@ -993,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } + +func (r *router) createIPSet(name string) error { + opts := ipset.CreateOptions{ + Replace: true, + } + + if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { + return fmt.Errorf("create ipset %s: %w", name, err) + } + + log.Debugf("created ipset %s with type hash:net", name) + return nil +} + +func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error { + addr := prefix.Addr() + ip := addr.AsSlice() + + entry := &ipset.Entry{ + IP: ip, + CIDR: uint8(prefix.Bits()), + Replace: true, + } + + if err := ipset.Add(name, entry); err != nil { + return fmt.Errorf("add prefix to ipset %s: %w", name, err) + } + + return nil +} + +func (r *router) destroyIPSet(name string) error { + return ipset.Destroy(name) +} diff --git a/go.mod b/go.mod index 14cd49192..2d7e0d31c 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 - github.com/vishvananda/netlink v1.3.0 + github.com/vishvananda/netlink v1.3.1 golang.org/x/crypto v0.40.0 golang.org/x/sys v0.34.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 @@ -59,10 +59,10 @@ require ( github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 + github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 @@ -236,7 +236,7 @@ require ( github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect - github.com/vishvananda/netns v0.0.4 // indirect + github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wlynxg/anet v0.0.3 // indirect github.com/yuin/goldmark v1.7.8 // indirect diff --git a/go.sum b/go.sum index ee5027d61..f4b62dff0 100644 --- a/go.sum +++ b/go.sum @@ -316,6 +316,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= +github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= +github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= @@ -356,8 +358,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= -github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= @@ -519,10 +519,10 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= -github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= -github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= From 0d79301141d9b856d05a000e1e5a21517f5aad5e Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Mon, 17 Nov 2025 15:28:20 +0100 Subject: [PATCH 058/118] Update client login success page (#4797) --- client/internal/auth/pkce_flow.go | 11 +- client/internal/templates/pkce-auth-msg.html | 155 ++++----- .../internal/templates/pkce_auth_msg_test.go | 299 ++++++++++++++++++ 3 files changed, 386 insertions(+), 79 deletions(-) create mode 100644 client/internal/templates/pkce_auth_msg_test.go diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 738d3e34f..48873f640 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -192,17 +192,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, if authError := query.Get(queryError); authError != "" { authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) + if authErrorDesc != "" { + return nil, fmt.Errorf("authentication failed: %s", authErrorDesc) + } + return nil, fmt.Errorf("authentication failed: %s", authError) } // Prevent timing attacks on the state if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") + return nil, fmt.Errorf("authentication failed: Invalid state") } code := query.Get(queryCode) if code == "" { - return nil, fmt.Errorf("missing code") + return nil, fmt.Errorf("authentication failed: missing code") } return p.oAuthConfig.Exchange( @@ -231,7 +234,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, } if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil { - return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) + return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err) } email, err := parseEmailFromIDToken(tokenInfo.IDToken) diff --git a/client/internal/templates/pkce-auth-msg.html b/client/internal/templates/pkce-auth-msg.html index 4825c48e7..175a6f05c 100644 --- a/client/internal/templates/pkce-auth-msg.html +++ b/client/internal/templates/pkce-auth-msg.html @@ -1,88 +1,93 @@ + - + + + + NetBird Login + + + - NetBird Login Successful + -
- -
- {{ if .Error }} - - - - -
-
- Login failed +
+
+ + +
+ + + + + + + + + + + + + + + + + + +
- {{ .Error }}. -
- {{ else }} - - - - -
-
- Login successful + +
+ +
+ + {{ if .Error }} + +
+ + + + +
+ {{ else }} + +
+ + + + +
+ {{ end }} + + +
+ {{ if .Error }} +

Login Failed

+ {{ else }} +

Login Successful

+ {{ end }} +
+ + + {{ if .Error }} +
+ {{ .Error }} +
+ {{ else }} + +
+ Your device is now registered and logged in to NetBird. You can now close this window. +
+ {{ end }} + +
- Your device is now registered and logged in to NetBird. -
- You can now close this window.
- {{ end }}
+ diff --git a/client/internal/templates/pkce_auth_msg_test.go b/client/internal/templates/pkce_auth_msg_test.go new file mode 100644 index 000000000..75b1c9e76 --- /dev/null +++ b/client/internal/templates/pkce_auth_msg_test.go @@ -0,0 +1,299 @@ +package templates + +import ( + "html/template" + "os" + "path/filepath" + "testing" +) + +func TestPKCEAuthMsgTemplate(t *testing.T) { + tests := []struct { + name string + data map[string]string + outputFile string + expectedTitle string + expectedInContent []string + notExpectedInContent []string + }{ + { + name: "error_state", + data: map[string]string{ + "Error": "authentication failed: invalid state", + }, + outputFile: "pkce-auth-error.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication failed: invalid state", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + { + name: "success_state", + data: map[string]string{ + // No error field means success + }, + outputFile: "pkce-auth-success.html", + expectedTitle: "Login Successful", + expectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird. You can now close this window.", + }, + notExpectedInContent: []string{ + "Login Failed", + }, + }, + { + name: "error_state_timeout", + data: map[string]string{ + "Error": "authentication timeout: request expired after 5 minutes", + }, + outputFile: "pkce-auth-timeout.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication timeout: request expired after 5 minutes", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the template + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + // Create temp directory for this test + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, tt.outputFile) + + // Create output file + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + + // Execute the template + if err := tmpl.Execute(file, tt.data); err != nil { + file.Close() + t.Fatalf("Failed to execute template: %v", err) + } + file.Close() + + t.Logf("Generated test output: %s", outputPath) + + // Read the generated file + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + contentStr := string(content) + + // Verify file has content + if len(contentStr) == 0 { + t.Error("Output file is empty") + } + + // Verify basic HTML structure + basicElements := []string{ + "", + "", + "", + "NetBird", + } + + for _, elem := range basicElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + + // Verify expected title + if !contains(contentStr, tt.expectedTitle) { + t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle) + } + + // Verify expected content is present + for _, expected := range tt.expectedInContent { + if !contains(contentStr, expected) { + t.Errorf("Expected HTML to contain '%s', but it was not found", expected) + } + } + + // Verify unexpected content is not present + for _, notExpected := range tt.notExpectedInContent { + if contains(contentStr, notExpected) { + t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected) + } + } + }) + } +} + +func TestPKCEAuthMsgTemplateValidation(t *testing.T) { + // Test that the template can be parsed without errors + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + // Test with empty data + t.Run("empty_data", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "empty-data.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + if err := tmpl.Execute(file, nil); err != nil { + t.Errorf("Template execution with nil data failed: %v", err) + } + }) + + // Test with error data + t.Run("with_error", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "with-error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{ + "Error": "test error message", + } + if err := tmpl.Execute(file, data); err != nil { + t.Errorf("Template execution with error data failed: %v", err) + } + }) +} + +func TestPKCEAuthMsgTemplateContent(t *testing.T) { + // Test that the template contains expected elements + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + t.Run("success_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "success.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{} + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for success indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Successful", + "NetBird", + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) + + t.Run("error_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + errorMsg := "test error message" + data := map[string]string{ + "Error": errorMsg, + } + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for error indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Failed", + errorMsg, + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && containsHelper(s, substr))) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} From d71a82769c32238cf29bd1bb3bd3ec520c0c1ab9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 17 Nov 2025 17:10:41 +0100 Subject: [PATCH 059/118] [client,management] Rewrite the SSH feature (#4015) --- .../workflows/check-license-dependencies.yml | 52 +- client/android/client.go | 2 +- client/android/preferences.go | 88 + client/cmd/root.go | 3 - client/cmd/ssh.go | 866 +++++++++- client/cmd/ssh_exec_unix.go | 74 + client/cmd/ssh_sftp_unix.go | 94 ++ client/cmd/ssh_sftp_windows.go | 94 ++ client/cmd/ssh_test.go | 717 ++++++++ client/cmd/status.go | 2 +- client/cmd/testutil_test.go | 3 +- client/cmd/up.go | 68 + client/embed/embed.go | 70 +- client/firewall/uspfilter/filter.go | 12 +- client/firewall/uspfilter/filter_test.go | 136 ++ .../firewall/uspfilter/nat_stateful_test.go | 85 + client/internal/acl/manager.go | 17 - client/internal/acl/manager_test.go | 67 - client/internal/connect.go | 38 +- client/internal/debug/debug.go | 12 + client/internal/engine.go | 143 +- client/internal/engine_ssh.go | 355 ++++ client/internal/engine_test.go | 134 +- client/internal/login.go | 10 + client/internal/peer/conn.go | 2 +- client/internal/peer/env.go | 4 + client/internal/peer/status.go | 19 +- client/internal/profilemanager/config.go | 128 +- client/internal/profilemanager/config_test.go | 8 +- .../internal/profilemanager/profilemanager.go | 18 + client/internal/routemanager/dynamic/route.go | 2 +- client/internal/routemanager/manager.go | 2 +- client/ios/NetBirdSDK/client.go | 2 +- client/proto/daemon.pb.go | 1458 +++++++++++++---- client/proto/daemon.proto | 177 +- client/proto/daemon_grpc.pb.go | 114 ++ client/server/jwt_cache.go | 79 + client/server/network.go | 2 +- client/server/server.go | 317 +++- client/server/server_test.go | 3 +- client/server/setconfig_test.go | 110 +- client/server/state_generic.go | 2 + client/server/state_linux.go | 2 + client/ssh/client.go | 118 -- client/ssh/client/client.go | 699 ++++++++ client/ssh/client/client_test.go | 512 ++++++ client/ssh/client/terminal_unix.go | 127 ++ client/ssh/client/terminal_windows.go | 265 +++ client/ssh/common.go | 171 ++ client/ssh/config/manager.go | 282 ++++ client/ssh/config/manager_test.go | 159 ++ client/ssh/config/shutdown_state.go | 22 + client/ssh/detection/detection.go | 99 ++ client/ssh/login.go | 53 - client/ssh/lookup.go | 14 - client/ssh/lookup_darwin.go | 51 - client/ssh/proxy/proxy.go | 392 +++++ client/ssh/proxy/proxy_test.go | 367 +++++ client/ssh/server.go | 280 ---- client/ssh/server/command_execution.go | 206 +++ client/ssh/server/command_execution_js.go | 52 + client/ssh/server/command_execution_unix.go | 329 ++++ .../ssh/server/command_execution_windows.go | 430 +++++ client/ssh/server/compatibility_test.go | 722 ++++++++ client/ssh/server/executor_unix.go | 253 +++ client/ssh/server/executor_unix_test.go | 262 +++ client/ssh/server/executor_windows.go | 570 +++++++ client/ssh/server/jwt_test.go | 629 +++++++ client/ssh/server/port_forwarding.go | 386 +++++ client/ssh/server/server.go | 712 ++++++++ client/ssh/server/server_config_test.go | 394 +++++ client/ssh/server/server_test.go | 441 +++++ client/ssh/server/session_handlers.go | 168 ++ client/ssh/server/session_handlers_js.go | 22 + client/ssh/server/sftp.go | 81 + client/ssh/server/sftp_js.go | 12 + client/ssh/server/sftp_test.go | 228 +++ client/ssh/server/sftp_unix.go | 71 + client/ssh/server/sftp_windows.go | 91 + client/ssh/server/shell.go | 180 ++ client/ssh/server/test.go | 45 + client/ssh/server/user_utils.go | 411 +++++ client/ssh/server/user_utils_js.go | 8 + client/ssh/server/user_utils_test.go | 908 ++++++++++ client/ssh/server/userswitching_js.go | 8 + client/ssh/server/userswitching_unix.go | 233 +++ client/ssh/server/userswitching_windows.go | 274 ++++ client/ssh/server/winpty/conpty.go | 487 ++++++ client/ssh/server/winpty/conpty_test.go | 290 ++++ client/ssh/server_mock.go | 46 - client/ssh/server_test.go | 123 -- client/ssh/{util.go => ssh.go} | 10 +- client/ssh/testutil/user_helpers.go | 172 ++ client/ssh/window_freebsd.go | 10 - client/ssh/window_unix.go | 14 - client/ssh/window_windows.go | 9 - client/status/status.go | 89 +- client/status/status_test.go | 17 +- client/system/info.go | 24 + client/ui/client_ui.go | 492 ++++-- client/wasm/cmd/main.go | 39 +- client/wasm/internal/ssh/client.go | 51 +- client/wasm/internal/ssh/key.go | 50 - go.mod | 16 +- go.sum | 56 +- management/internals/server/modules.go | 2 +- .../internals/shared/grpc/conversion.go | 67 +- management/internals/shared/grpc/server.go | 2 +- management/server/account.go | 28 +- management/server/account/manager.go | 11 +- management/server/account_test.go | 54 +- management/server/auth/manager.go | 17 +- management/server/auth/manager_mock.go | 15 +- management/server/auth/manager_test.go | 12 +- management/server/context/auth.go | 38 +- management/server/dns_test.go | 2 +- .../accounts/accounts_handler_test.go | 3 +- .../http/handlers/dns/dns_settings_handler.go | 2 +- .../handlers/dns/dns_settings_handler_test.go | 5 +- .../handlers/dns/nameservers_handler_test.go | 3 +- .../handlers/events/events_handler_test.go | 5 +- .../http/handlers/groups/groups_handler.go | 2 +- .../handlers/groups/groups_handler_test.go | 13 +- .../server/http/handlers/networks/handler.go | 6 +- .../handlers/networks/resources_handler.go | 4 +- .../http/handlers/networks/routers_handler.go | 4 +- .../http/handlers/peers/peers_handler_test.go | 7 +- .../policies/geolocation_handler_test.go | 7 +- .../handlers/policies/geolocations_handler.go | 4 +- .../handlers/policies/policies_handler.go | 2 +- .../policies/policies_handler_test.go | 9 +- .../policies/posture_checks_handler.go | 2 +- .../policies/posture_checks_handler_test.go | 7 +- .../handlers/routes/routes_handler_test.go | 3 +- .../handlers/setup_keys/setupkeys_handler.go | 2 +- .../setup_keys/setupkeys_handler_test.go | 7 +- .../server/http/handlers/users/pat_handler.go | 2 +- .../http/handlers/users/pat_handler_test.go | 7 +- .../http/handlers/users/users_handler_test.go | 35 +- .../server/http/middleware/auth_middleware.go | 35 +- .../http/middleware/auth_middleware_test.go | 63 +- .../testing/testing_tools/channel/channel.go | 15 +- management/server/idp/pocketid_test.go | 199 ++- management/server/management_proto_test.go | 2 +- management/server/management_test.go | 1 + management/server/mock_server/account_mock.go | 16 +- management/server/nameserver_test.go | 2 +- .../server/networks/resources/manager_test.go | 2 +- .../networks/resources/types/resource.go | 2 +- .../server/networks/routers/manager_test.go | 2 +- .../server/networks/routers/types/router.go | 2 +- management/server/peer_test.go | 8 +- management/server/posture/checks.go | 2 +- management/server/route_test.go | 2 +- .../server/types/route_firewall_rule.go | 2 +- management/server/user.go | 13 +- management/server/user_test.go | 26 +- relay/server/peer.go | 4 +- .../server => shared}/auth/jwt/extractor.go | 8 +- .../server => shared}/auth/jwt/validator.go | 0 shared/auth/user.go | 28 + shared/context/keys.go | 2 +- shared/management/client/client_test.go | 3 +- shared/management/operations/operation.go | 2 +- shared/management/proto/management.pb.go | 1414 +++++++++------- shared/management/proto/management.proto | 18 + shared/relay/client/dialer/quic/quic.go | 2 +- shared/relay/client/dialer/ws/ws.go | 2 +- shared/relay/constants.go | 2 +- version/url_windows.go | 6 +- 170 files changed, 18744 insertions(+), 2853 deletions(-) create mode 100644 client/cmd/ssh_exec_unix.go create mode 100644 client/cmd/ssh_sftp_unix.go create mode 100644 client/cmd/ssh_sftp_windows.go create mode 100644 client/cmd/ssh_test.go create mode 100644 client/firewall/uspfilter/nat_stateful_test.go create mode 100644 client/internal/engine_ssh.go create mode 100644 client/server/jwt_cache.go delete mode 100644 client/ssh/client.go create mode 100644 client/ssh/client/client.go create mode 100644 client/ssh/client/client_test.go create mode 100644 client/ssh/client/terminal_unix.go create mode 100644 client/ssh/client/terminal_windows.go create mode 100644 client/ssh/common.go create mode 100644 client/ssh/config/manager.go create mode 100644 client/ssh/config/manager_test.go create mode 100644 client/ssh/config/shutdown_state.go create mode 100644 client/ssh/detection/detection.go delete mode 100644 client/ssh/login.go delete mode 100644 client/ssh/lookup.go delete mode 100644 client/ssh/lookup_darwin.go create mode 100644 client/ssh/proxy/proxy.go create mode 100644 client/ssh/proxy/proxy_test.go delete mode 100644 client/ssh/server.go create mode 100644 client/ssh/server/command_execution.go create mode 100644 client/ssh/server/command_execution_js.go create mode 100644 client/ssh/server/command_execution_unix.go create mode 100644 client/ssh/server/command_execution_windows.go create mode 100644 client/ssh/server/compatibility_test.go create mode 100644 client/ssh/server/executor_unix.go create mode 100644 client/ssh/server/executor_unix_test.go create mode 100644 client/ssh/server/executor_windows.go create mode 100644 client/ssh/server/jwt_test.go create mode 100644 client/ssh/server/port_forwarding.go create mode 100644 client/ssh/server/server.go create mode 100644 client/ssh/server/server_config_test.go create mode 100644 client/ssh/server/server_test.go create mode 100644 client/ssh/server/session_handlers.go create mode 100644 client/ssh/server/session_handlers_js.go create mode 100644 client/ssh/server/sftp.go create mode 100644 client/ssh/server/sftp_js.go create mode 100644 client/ssh/server/sftp_test.go create mode 100644 client/ssh/server/sftp_unix.go create mode 100644 client/ssh/server/sftp_windows.go create mode 100644 client/ssh/server/shell.go create mode 100644 client/ssh/server/test.go create mode 100644 client/ssh/server/user_utils.go create mode 100644 client/ssh/server/user_utils_js.go create mode 100644 client/ssh/server/user_utils_test.go create mode 100644 client/ssh/server/userswitching_js.go create mode 100644 client/ssh/server/userswitching_unix.go create mode 100644 client/ssh/server/userswitching_windows.go create mode 100644 client/ssh/server/winpty/conpty.go create mode 100644 client/ssh/server/winpty/conpty_test.go delete mode 100644 client/ssh/server_mock.go delete mode 100644 client/ssh/server_test.go rename client/ssh/{util.go => ssh.go} (86%) create mode 100644 client/ssh/testutil/user_helpers.go delete mode 100644 client/ssh/window_freebsd.go delete mode 100644 client/ssh/window_unix.go delete mode 100644 client/ssh/window_windows.go delete mode 100644 client/wasm/internal/ssh/key.go rename {management/server => shared}/auth/jwt/extractor.go (92%) rename {management/server => shared}/auth/jwt/validator.go (100%) create mode 100644 shared/auth/user.go diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index 2a3e7d424..543ba2ab2 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -19,35 +19,37 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Check for problematic license dependencies - run: | - echo "Checking for dependencies on management/, signal/, and relay/ packages..." + - name: Check for problematic license dependencies + run: | + echo "Checking for dependencies on management/, signal/, and relay/ packages..." + echo "" - # Find all directories except the problematic ones and system dirs - FOUND_ISSUES=0 - find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do - echo "=== Checking $dir ===" - # Search for problematic imports, excluding test files - RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) - if [ ! -z "$RESULTS" ]; then - echo "❌ Found problematic dependencies:" - echo "$RESULTS" - FOUND_ISSUES=1 + # Find all directories except the problematic ones and system dirs + FOUND_ISSUES=0 + while IFS= read -r dir; do + echo "=== Checking $dir ===" + # Search for problematic imports, excluding test files + RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) + if [ -n "$RESULTS" ]; then + echo "❌ Found problematic dependencies:" + echo "$RESULTS" + FOUND_ISSUES=1 + else + echo "✓ No problematic dependencies found" + fi + done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort) + + echo "" + if [ $FOUND_ISSUES -eq 1 ]; then + echo "❌ Found dependencies on management/, signal/, or relay/ packages" + echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" + exit 1 else - echo "✓ No problematic dependencies found" + echo "" + echo "✅ All internal license dependencies are clean" fi - done - if [ $FOUND_ISSUES -eq 1 ]; then - echo "" - echo "❌ Found dependencies on management/, signal/, or relay/ packages" - echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" - exit 1 - else - echo "" - echo "✅ All internal license dependencies are clean" - fi check-external-licenses: name: Check External GPL/AGPL Licenses diff --git a/client/android/client.go b/client/android/client.go index d2d0c37f6..86fb1445d 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -17,9 +17,9 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/client/net" ) // ConnectionListener export internal Listener for mobile diff --git a/client/android/preferences.go b/client/android/preferences.go index 9a5d6bb21..c3c8eb3fb 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) { p.configInput.ServerSSHAllowed = &allowed } +// GetEnableSSHRoot reads SSH root login setting from config file +func (p *Preferences) GetEnableSSHRoot() (bool, error) { + if p.configInput.EnableSSHRoot != nil { + return *p.configInput.EnableSSHRoot, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHRoot == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHRoot, err +} + +// SetEnableSSHRoot stores the given value and waits for commit +func (p *Preferences) SetEnableSSHRoot(enabled bool) { + p.configInput.EnableSSHRoot = &enabled +} + +// GetEnableSSHSFTP reads SSH SFTP setting from config file +func (p *Preferences) GetEnableSSHSFTP() (bool, error) { + if p.configInput.EnableSSHSFTP != nil { + return *p.configInput.EnableSSHSFTP, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHSFTP == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHSFTP, err +} + +// SetEnableSSHSFTP stores the given value and waits for commit +func (p *Preferences) SetEnableSSHSFTP(enabled bool) { + p.configInput.EnableSSHSFTP = &enabled +} + +// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file +func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) { + if p.configInput.EnableSSHLocalPortForwarding != nil { + return *p.configInput.EnableSSHLocalPortForwarding, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHLocalPortForwarding == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHLocalPortForwarding, err +} + +// SetEnableSSHLocalPortForwarding stores the given value and waits for commit +func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) { + p.configInput.EnableSSHLocalPortForwarding = &enabled +} + +// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file +func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) { + if p.configInput.EnableSSHRemotePortForwarding != nil { + return *p.configInput.EnableSSHRemotePortForwarding, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHRemotePortForwarding == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHRemotePortForwarding, err +} + +// SetEnableSSHRemotePortForwarding stores the given value and waits for commit +func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) { + p.configInput.EnableSSHRemotePortForwarding = &enabled +} + // GetBlockInbound reads block inbound setting from config file func (p *Preferences) GetBlockInbound() (bool, error) { if p.configInput.BlockInbound != nil { diff --git a/client/cmd/root.go b/client/cmd/root.go index 11e5228f1..9f2eb109c 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -35,7 +35,6 @@ const ( wireguardPortFlag = "wireguard-port" networkMonitorFlag = "network-monitor" disableAutoConnectFlag = "disable-auto-connect" - serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" enableLazyConnectionFlag = "enable-lazy-connection" @@ -64,7 +63,6 @@ var ( customDNSAddress string rosenpassEnabled bool rosenpassPermissive bool - serverSSHAllowed bool interfaceName string wireguardPort uint16 networkMonitor bool @@ -176,7 +174,6 @@ func init() { ) upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") - upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 5358ddacb..70c7dbcff 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -3,125 +3,809 @@ package cmd import ( "context" "errors" + "flag" "fmt" + "net" "os" "os/signal" + "os/user" + "slices" + "strconv" "strings" "syscall" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/crypto/ssh" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/client/internal/profilemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" + sshclient "github.com/netbirdio/netbird/client/ssh/client" + "github.com/netbirdio/netbird/client/ssh/detection" + sshproxy "github.com/netbirdio/netbird/client/ssh/proxy" + sshserver "github.com/netbirdio/netbird/client/ssh/server" "github.com/netbirdio/netbird/util" ) -var ( - port int - userName = "root" - host string +const ( + sshUsernameDesc = "SSH username" + hostArgumentRequired = "host argument required" + + serverSSHAllowedFlag = "allow-server-ssh" + enableSSHRootFlag = "enable-ssh-root" + enableSSHSFTPFlag = "enable-ssh-sftp" + enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding" + enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding" + disableSSHAuthFlag = "disable-ssh-auth" + sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl" ) -var sshCmd = &cobra.Command{ - Use: "ssh [user@]host", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return errors.New("requires a host argument") - } +var ( + port int + username string + host string + command string + localForwards []string + remoteForwards []string + strictHostKeyChecking bool + knownHostsFile string + identityFile string + skipCachedToken bool + requestPTY bool +) - split := strings.Split(args[0], "@") - if len(split) == 2 { - userName = split[0] - host = split[1] - } else { - host = args[0] - } +var ( + serverSSHAllowed bool + enableSSHRoot bool + enableSSHSFTP bool + enableSSHLocalPortForward bool + enableSSHRemotePortForward bool + disableSSHAuth bool + sshJWTCacheTTL int +) - return nil - }, - Short: "Connect to a remote SSH server", - RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - SetFlagsFromEnvVars(cmd) +func init() { + upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer") + upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server") + upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication") + upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)") - cmd.SetOut(cmd.OutOrStdout()) + sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port") + sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc) + sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)") + sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation") + sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)") + sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)") + sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") + _ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used") + sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") - err := util.InitLog(logLevel, util.LogConsole) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } + sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") + sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") - if !util.IsAdmin() { - cmd.Printf("error: you must have Administrator privileges to run this command\n") - return nil - } - - ctx := internal.CtxInitState(cmd.Context()) - - sm := profilemanager.NewServiceManager(configPath) - activeProf, err := sm.GetActiveProfileState() - if err != nil { - return fmt.Errorf("get active profile: %v", err) - } - profPath, err := activeProf.FilePath() - if err != nil { - return fmt.Errorf("get active profile path: %v", err) - } - - config, err := profilemanager.ReadConfig(profPath) - if err != nil { - return fmt.Errorf("read profile config: %v", err) - } - - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) - sshctx, cancel := context.WithCancel(ctx) - - go func() { - // blocking - if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { - cmd.Printf("Error: %v\n", err) - os.Exit(1) - } - cancel() - }() - - select { - case <-sig: - cancel() - case <-sshctx.Done(): - } - - return nil - }, + sshCmd.AddCommand(sshSftpCmd) + sshCmd.AddCommand(sshProxyCmd) + sshCmd.AddCommand(sshDetectCmd) } -func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error { - c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey) - if err != nil { - cmd.Printf("Error: %v\n", err) - cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + - "\nYou can verify the connection by running:\n\n" + - " netbird status\n\n") - return err - } - go func() { - <-ctx.Done() - err = c.Close() - if err != nil { - return +var sshCmd = &cobra.Command{ + Use: "ssh [flags] [user@]host [command]", + Short: "Connect to a NetBird peer via SSH", + Long: `Connect to a NetBird peer using SSH with support for port forwarding. + +Port Forwarding: + -L [bind_address:]port:host:hostport Local port forwarding + -L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket + -R [bind_address:]port:host:hostport Remote port forwarding + -R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket + +SSH Options: + -p, --port int Remote SSH port (default 22) + -u, --user string SSH username + --login string SSH username (alias for --user) + -t, --tty Force pseudo-terminal allocation + --strict-host-key-checking Enable strict host key checking (default: true) + -o, --known-hosts string Path to known_hosts file + +Examples: + netbird ssh peer-hostname + netbird ssh root@peer-hostname + netbird ssh --login root peer-hostname + netbird ssh peer-hostname ls -la + netbird ssh peer-hostname whoami + netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen + netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo + netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding + netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding + netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces + netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`, + DisableFlagParsing: true, + Args: validateSSHArgsWithoutFlagParsing, + RunE: sshFn, + Aliases: []string{"ssh"}, +} + +func sshFn(cmd *cobra.Command, args []string) error { + for _, arg := range args { + if arg == "-h" || arg == "--help" { + return cmd.Help() } + } + + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(cmd) + + cmd.SetOut(cmd.OutOrStdout()) + + logOutput := "console" + if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { + logOutput = firstLogFile + } + if err := util.InitLog(logLevel, logOutput); err != nil { + return fmt.Errorf("init log: %w", err) + } + + ctx := internal.CtxInitState(cmd.Context()) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) + sshctx, cancel := context.WithCancel(ctx) + + errCh := make(chan error, 1) + go func() { + if err := runSSH(sshctx, host, cmd); err != nil { + errCh <- err + } + cancel() }() - err = c.OpenTerminal() - if err != nil { + select { + case <-sig: + cancel() + <-sshctx.Done() + return nil + case err := <-errCh: return err + case <-sshctx.Done(): } return nil } -func init() { - sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort)) +// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes +func getEnvOrDefault(flagName, defaultValue string) string { + if envValue := os.Getenv("WT_" + flagName); envValue != "" { + return envValue + } + if envValue := os.Getenv("NB_" + flagName); envValue != "" { + return envValue + } + return defaultValue +} + +// resetSSHGlobals sets SSH globals to their default values +func resetSSHGlobals() { + port = sshserver.DefaultSSHPort + username = "" + host = "" + command = "" + localForwards = nil + remoteForwards = nil + strictHostKeyChecking = true + knownHostsFile = "" + identityFile = "" +} + +// parseCustomSSHFlags extracts -L, -R flags and returns filtered args +func parseCustomSSHFlags(args []string) ([]string, []string, []string) { + var localForwardFlags []string + var remoteForwardFlags []string + var filteredArgs []string + + for i := 0; i < len(args); i++ { + arg := args[i] + switch { + case strings.HasPrefix(arg, "-L"): + localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags) + case strings.HasPrefix(arg, "-R"): + remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags) + default: + filteredArgs = append(filteredArgs, arg) + } + } + + return filteredArgs, localForwardFlags, remoteForwardFlags +} + +func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) { + if arg == "-L" || arg == "-R" { + if i+1 < len(args) { + flags = append(flags, args[i+1]) + i++ + } + } else if len(arg) > 2 { + flags = append(flags, arg[2:]) + } + return flags, i +} + +// extractGlobalFlags parses global flags that were passed before 'ssh' command +func extractGlobalFlags(args []string) { + sshPos := findSSHCommandPosition(args) + if sshPos == -1 { + return + } + + globalArgs := args[:sshPos] + parseGlobalArgs(globalArgs) +} + +// findSSHCommandPosition locates the 'ssh' command in the argument list +func findSSHCommandPosition(args []string) int { + for i, arg := range args { + if arg == "ssh" { + return i + } + } + return -1 +} + +const ( + configFlag = "config" + logLevelFlag = "log-level" + logFileFlag = "log-file" +) + +// parseGlobalArgs processes the global arguments and sets the corresponding variables +func parseGlobalArgs(globalArgs []string) { + flagHandlers := map[string]func(string){ + configFlag: func(value string) { configPath = value }, + logLevelFlag: func(value string) { logLevel = value }, + logFileFlag: func(value string) { + if !slices.Contains(logFiles, value) { + logFiles = append(logFiles, value) + } + }, + } + + shortFlags := map[string]string{ + "c": configFlag, + "l": logLevelFlag, + } + + for i := 0; i < len(globalArgs); i++ { + arg := globalArgs[i] + + if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled { + i = nextIndex + } + } +} + +// parseFlag handles generic flag parsing for both long and short forms +func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) { + if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found { + flagHandlers[parsedValue.flagName](parsedValue.value) + return true, currentIndex + } + + if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found { + flagHandlers[parsedValue.flagName](parsedValue.value) + return true, currentIndex + 1 + } + + return false, currentIndex +} + +type parsedFlag struct { + flagName string + value string +} + +// parseEqualsFormat handles --flag=value and -f=value formats +func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) { + if !strings.Contains(arg, "=") { + return parsedFlag{}, false + } + + parts := strings.SplitN(arg, "=", 2) + if len(parts) != 2 { + return parsedFlag{}, false + } + + if strings.HasPrefix(parts[0], "--") { + flagName := strings.TrimPrefix(parts[0], "--") + if _, exists := flagHandlers[flagName]; exists { + return parsedFlag{flagName: flagName, value: parts[1]}, true + } + } + + if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 { + shortFlag := strings.TrimPrefix(parts[0], "-") + if longFlag, exists := shortFlags[shortFlag]; exists { + if _, exists := flagHandlers[longFlag]; exists { + return parsedFlag{flagName: longFlag, value: parts[1]}, true + } + } + } + + return parsedFlag{}, false +} + +// parseSpacedFormat handles --flag value and -f value formats +func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) { + if currentIndex+1 >= len(args) { + return parsedFlag{}, false + } + + if strings.HasPrefix(arg, "--") { + flagName := strings.TrimPrefix(arg, "--") + if _, exists := flagHandlers[flagName]; exists { + return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true + } + } + + if strings.HasPrefix(arg, "-") && len(arg) == 2 { + shortFlag := strings.TrimPrefix(arg, "-") + if longFlag, exists := shortFlags[shortFlag]; exists { + if _, exists := flagHandlers[longFlag]; exists { + return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true + } + } + } + + return parsedFlag{}, false +} + +// createSSHFlagSet creates and configures the flag set for SSH command parsing +// sshFlags contains all SSH-related flags and parameters +type sshFlags struct { + Port int + Username string + Login string + RequestPTY bool + StrictHostKeyChecking bool + KnownHostsFile string + IdentityFile string + SkipCachedToken bool + ConfigPath string + LogLevel string + LocalForwards []string + RemoteForwards []string + Host string + Command string +} + +func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { + defaultConfigPath := getEnvOrDefault("CONFIG", configPath) + defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + + fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) + fs.SetOutput(nil) + + flags := &sshFlags{} + + fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port") + fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port") + fs.StringVar(&flags.Username, "u", "", sshUsernameDesc) + fs.StringVar(&flags.Username, "user", "", sshUsernameDesc) + fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)") + fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation") + fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation") + + fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking") + fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file") + fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file") + fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") + fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") + fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + + fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") + fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") + fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level") + fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level") + + return fs, flags +} + +func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error { + if len(args) < 1 { + return errors.New(hostArgumentRequired) + } + + resetSSHGlobals() + + if len(os.Args) > 2 { + extractGlobalFlags(os.Args[1:]) + } + + filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args) + + fs, flags := createSSHFlagSet() + + if err := fs.Parse(filteredArgs); err != nil { + if errors.Is(err, flag.ErrHelp) { + return nil + } + return err + } + + remaining := fs.Args() + if len(remaining) < 1 { + return errors.New(hostArgumentRequired) + } + + port = flags.Port + if flags.Username != "" { + username = flags.Username + } else if flags.Login != "" { + username = flags.Login + } + + requestPTY = flags.RequestPTY + strictHostKeyChecking = flags.StrictHostKeyChecking + knownHostsFile = flags.KnownHostsFile + identityFile = flags.IdentityFile + skipCachedToken = flags.SkipCachedToken + + if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { + configPath = flags.ConfigPath + } + if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) { + logLevel = flags.LogLevel + } + + localForwards = localForwardFlags + remoteForwards = remoteForwardFlags + + return parseHostnameAndCommand(remaining) +} + +func parseHostnameAndCommand(args []string) error { + if len(args) < 1 { + return errors.New(hostArgumentRequired) + } + + arg := args[0] + if strings.Contains(arg, "@") { + parts := strings.SplitN(arg, "@", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return errors.New("invalid user@host format") + } + if username == "" { + username = parts[0] + } + host = parts[1] + } else { + host = arg + } + + if username == "" { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + username = sudoUser + } else if currentUser, err := user.Current(); err == nil { + username = currentUser.Username + } else { + username = "root" + } + } + + // Everything after hostname becomes the command + if len(args) > 1 { + command = strings.Join(args[1:], " ") + } + + return nil +} + +func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { + target := fmt.Sprintf("%s:%d", addr, port) + c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{ + KnownHostsFile: knownHostsFile, + IdentityFile: identityFile, + DaemonAddr: daemonAddr, + SkipCachedToken: skipCachedToken, + InsecureSkipVerify: !strictHostKeyChecking, + }) + + if err != nil { + cmd.Printf("Failed to connect to %s@%s\n", username, target) + cmd.Printf("\nTroubleshooting steps:\n") + cmd.Printf(" 1. Check peer connectivity: netbird status -d\n") + cmd.Printf(" 2. Verify SSH server is enabled on the peer\n") + cmd.Printf(" 3. Ensure correct hostname/IP is used\n") + return fmt.Errorf("dial %s: %w", target, err) + } + + sshCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-sshCtx.Done() + if err := c.Close(); err != nil { + cmd.Printf("Error closing SSH connection: %v\n", err) + } + }() + + if err := startPortForwarding(sshCtx, c, cmd); err != nil { + return fmt.Errorf("start port forwarding: %w", err) + } + + if command != "" { + return executeSSHCommand(sshCtx, c, command) + } + return openSSHTerminal(sshCtx, c) +} + +// executeSSHCommand executes a command over SSH. +func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error { + var err error + if requestPTY { + err = c.ExecuteCommandWithPTY(ctx, command) + } else { + err = c.ExecuteCommandWithIO(ctx, command) + } + + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + + var exitErr *ssh.ExitError + if errors.As(err, &exitErr) { + os.Exit(exitErr.ExitStatus()) + } + + var exitMissingErr *ssh.ExitMissingError + if errors.As(err, &exitMissingErr) { + log.Debugf("Remote command exited without exit status: %v", err) + return nil + } + + return fmt.Errorf("execute command: %w", err) + } + return nil +} + +// openSSHTerminal opens an interactive SSH terminal. +func openSSHTerminal(ctx context.Context, c *sshclient.Client) error { + if err := c.OpenTerminal(ctx); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + + var exitMissingErr *ssh.ExitMissingError + if errors.As(err, &exitMissingErr) { + log.Debugf("Remote terminal exited without exit status: %v", err) + return nil + } + + return fmt.Errorf("open terminal: %w", err) + } + return nil +} + +// startPortForwarding starts local and remote port forwarding based on command line flags +func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error { + for _, forward := range localForwards { + if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil { + return fmt.Errorf("local port forward %s: %w", forward, err) + } + } + + for _, forward := range remoteForwards { + if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil { + return fmt.Errorf("remote port forward %s: %w", forward, err) + } + } + + return nil +} + +// parseAndStartLocalForward parses and starts a local port forward (-L) +func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error { + localAddr, remoteAddr, err := parsePortForwardSpec(forward) + if err != nil { + return err + } + + cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr) + + go func() { + if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) { + cmd.Printf("Local port forward error: %v\n", err) + } + }() + + return nil +} + +// parseAndStartRemoteForward parses and starts a remote port forward (-R) +func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error { + remoteAddr, localAddr, err := parsePortForwardSpec(forward) + if err != nil { + return err + } + + cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr) + + go func() { + if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) { + cmd.Printf("Remote port forward error: %v\n", err) + } + }() + + return nil +} + +// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80". +// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket". +func parsePortForwardSpec(spec string) (string, string, error) { + // Support formats: + // port:host:hostport -> localhost:port -> host:hostport + // host:port:host:hostport -> host:port -> host:hostport + // [host]:port:host:hostport -> [host]:port -> host:hostport + // port:unix_socket_path -> localhost:port -> unix_socket_path + // host:port:unix_socket_path -> host:port -> unix_socket_path + + if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") { + return parseIPv6ForwardSpec(spec) + } + + parts := strings.Split(spec, ":") + if len(parts) < 2 { + return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec) + } + + switch len(parts) { + case 2: + return parseTwoPartForwardSpec(parts, spec) + case 3: + return parseThreePartForwardSpec(parts) + case 4: + return parseFourPartForwardSpec(parts) + default: + return "", "", fmt.Errorf("invalid port forward specification: %s", spec) + } +} + +// parseTwoPartForwardSpec handles "port:unix_socket" format. +func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) { + if isUnixSocket(parts[1]) { + localAddr := "localhost:" + parts[0] + remoteAddr := parts[1] + return localAddr, remoteAddr, nil + } + return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec) +} + +// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats. +func parseThreePartForwardSpec(parts []string) (string, string, error) { + if isUnixSocket(parts[2]) { + localHost := normalizeLocalHost(parts[0]) + localAddr := localHost + ":" + parts[1] + remoteAddr := parts[2] + return localAddr, remoteAddr, nil + } + localAddr := "localhost:" + parts[0] + remoteAddr := parts[1] + ":" + parts[2] + return localAddr, remoteAddr, nil +} + +// parseFourPartForwardSpec handles "host:port:host:hostport" format. +func parseFourPartForwardSpec(parts []string) (string, string, error) { + localHost := normalizeLocalHost(parts[0]) + localAddr := localHost + ":" + parts[1] + remoteAddr := parts[2] + ":" + parts[3] + return localAddr, remoteAddr, nil +} + +// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format. +func parseIPv6ForwardSpec(spec string) (string, string, error) { + idx := strings.Index(spec, "]:") + if idx == -1 { + return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec) + } + + ipv6Host := spec[:idx+1] + remaining := spec[idx+2:] + + parts := strings.Split(remaining, ":") + if len(parts) != 3 { + return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec) + } + + localAddr := ipv6Host + ":" + parts[0] + remoteAddr := parts[1] + ":" + parts[2] + return localAddr, remoteAddr, nil +} + +// isUnixSocket checks if a path is a Unix socket path. +func isUnixSocket(path string) bool { + return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") +} + +// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +func normalizeLocalHost(host string) string { + if host == "*" { + return "0.0.0.0" + } + return host +} + +var sshProxyCmd = &cobra.Command{ + Use: "proxy ", + Short: "Internal SSH proxy for native SSH client integration", + Long: "Internal command used by SSH ProxyCommand to handle JWT authentication", + Hidden: true, + Args: cobra.ExactArgs(2), + RunE: sshProxyFn, +} + +func sshProxyFn(cmd *cobra.Command, args []string) error { + logOutput := "console" + if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { + logOutput = firstLogFile + } + if err := util.InitLog(logLevel, logOutput); err != nil { + return fmt.Errorf("init log: %w", err) + } + + host := args[0] + portStr := args[1] + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port: %s", portStr) + } + + proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr()) + if err != nil { + return fmt.Errorf("create SSH proxy: %w", err) + } + defer func() { + if err := proxy.Close(); err != nil { + log.Debugf("close SSH proxy: %v", err) + } + }() + + if err := proxy.Connect(cmd.Context()); err != nil { + return fmt.Errorf("SSH proxy: %w", err) + } + + return nil +} + +var sshDetectCmd = &cobra.Command{ + Use: "detect ", + Short: "Detect if a host is running NetBird SSH", + Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH", + Hidden: true, + Args: cobra.ExactArgs(2), + RunE: sshDetectFn, +} + +func sshDetectFn(cmd *cobra.Command, args []string) error { + if err := util.InitLog(logLevel, "console"); err != nil { + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + host := args[0] + portStr := args[1] + + port, err := strconv.Atoi(portStr) + if err != nil { + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port) + if err != nil { + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + os.Exit(serverType.ExitCode()) + return nil } diff --git a/client/cmd/ssh_exec_unix.go b/client/cmd/ssh_exec_unix.go new file mode 100644 index 000000000..2412f072c --- /dev/null +++ b/client/cmd/ssh_exec_unix.go @@ -0,0 +1,74 @@ +//go:build unix + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sshExecUID uint32 + sshExecGID uint32 + sshExecGroups []uint + sshExecWorkingDir string + sshExecShell string + sshExecCommand string + sshExecPTY bool +) + +// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping +var sshExecCmd = &cobra.Command{ + Use: "exec", + Short: "Internal SSH execution with privilege dropping (hidden)", + Hidden: true, + RunE: runSSHExec, +} + +func init() { + sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID") + sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID") + sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)") + sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory") + sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute") + sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)") + sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute") + + if err := sshExecCmd.MarkFlagRequired("uid"); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err) + os.Exit(1) + } + if err := sshExecCmd.MarkFlagRequired("gid"); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err) + os.Exit(1) + } + + sshCmd.AddCommand(sshExecCmd) +} + +// runSSHExec handles the SSH exec subcommand execution. +func runSSHExec(cmd *cobra.Command, _ []string) error { + privilegeDropper := sshserver.NewPrivilegeDropper() + + var groups []uint32 + for _, groupInt := range sshExecGroups { + groups = append(groups, uint32(groupInt)) + } + + config := sshserver.ExecutorConfig{ + UID: sshExecUID, + GID: sshExecGID, + Groups: groups, + WorkingDir: sshExecWorkingDir, + Shell: sshExecShell, + Command: sshExecCommand, + PTY: sshExecPTY, + } + + privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config) + return nil +} diff --git a/client/cmd/ssh_sftp_unix.go b/client/cmd/ssh_sftp_unix.go new file mode 100644 index 000000000..c06aab017 --- /dev/null +++ b/client/cmd/ssh_sftp_unix.go @@ -0,0 +1,94 @@ +//go:build unix + +package cmd + +import ( + "errors" + "io" + "os" + + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sftpUID uint32 + sftpGID uint32 + sftpGroupsInt []uint + sftpWorkingDir string +) + +var sshSftpCmd = &cobra.Command{ + Use: "sftp", + Short: "SFTP server with privilege dropping (internal use)", + Hidden: true, + RunE: sftpMain, +} + +func init() { + sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID") + sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID") + sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)") + sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory") +} + +func sftpMain(cmd *cobra.Command, _ []string) error { + privilegeDropper := sshserver.NewPrivilegeDropper() + + var groups []uint32 + for _, groupInt := range sftpGroupsInt { + groups = append(groups, uint32(groupInt)) + } + + config := sshserver.ExecutorConfig{ + UID: sftpUID, + GID: sftpGID, + Groups: groups, + WorkingDir: sftpWorkingDir, + Shell: "", + Command: "", + } + + log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups) + + if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil { + cmd.PrintErrf("privilege drop failed: %v\n", err) + os.Exit(sshserver.ExitCodePrivilegeDropFail) + } + + if config.WorkingDir != "" { + if err := os.Chdir(config.WorkingDir); err != nil { + cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err) + } + } + + sftpServer, err := sftp.NewServer(struct { + io.Reader + io.WriteCloser + }{ + Reader: os.Stdin, + WriteCloser: os.Stdout, + }) + if err != nil { + cmd.PrintErrf("SFTP server creation failed: %v\n", err) + os.Exit(sshserver.ExitCodeShellExecFail) + } + + log.Tracef("starting SFTP server with dropped privileges") + if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) { + cmd.PrintErrf("SFTP server error: %v\n", err) + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } + os.Exit(sshserver.ExitCodeShellExecFail) + } + + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } + os.Exit(sshserver.ExitCodeSuccess) + return nil +} diff --git a/client/cmd/ssh_sftp_windows.go b/client/cmd/ssh_sftp_windows.go new file mode 100644 index 000000000..ffd2d1148 --- /dev/null +++ b/client/cmd/ssh_sftp_windows.go @@ -0,0 +1,94 @@ +//go:build windows + +package cmd + +import ( + "errors" + "fmt" + "io" + "os" + "os/user" + "strings" + + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sftpWorkingDir string + windowsUsername string + windowsDomain string +) + +var sshSftpCmd = &cobra.Command{ + Use: "sftp", + Short: "SFTP server with user switching for Windows (internal use)", + Hidden: true, + RunE: sftpMain, +} + +func init() { + sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory") + sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching") + sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching") +} + +func sftpMain(cmd *cobra.Command, _ []string) error { + return sftpMainDirect(cmd) +} + +func sftpMainDirect(cmd *cobra.Command) error { + currentUser, err := user.Current() + if err != nil { + cmd.PrintErrf("failed to get current user: %v\n", err) + os.Exit(sshserver.ExitCodeValidationFail) + } + + if windowsUsername != "" { + expectedUsername := windowsUsername + if windowsDomain != "" { + expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername) + } + if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) { + cmd.PrintErrf("user switching failed\n") + os.Exit(sshserver.ExitCodeValidationFail) + } + } + + log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name) + + if sftpWorkingDir != "" { + if err := os.Chdir(sftpWorkingDir); err != nil { + cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err) + } + } + + sftpServer, err := sftp.NewServer(struct { + io.Reader + io.WriteCloser + }{ + Reader: os.Stdin, + WriteCloser: os.Stdout, + }) + if err != nil { + cmd.PrintErrf("SFTP server creation failed: %v\n", err) + os.Exit(sshserver.ExitCodeShellExecFail) + } + + log.Debugf("starting SFTP server") + exitCode := sshserver.ExitCodeSuccess + if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) { + cmd.PrintErrf("SFTP server error: %v\n", err) + exitCode = sshserver.ExitCodeShellExecFail + } + + if err := sftpServer.Close(); err != nil { + log.Debugf("SFTP server close error: %v", err) + } + + os.Exit(exitCode) + return nil +} diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go new file mode 100644 index 000000000..43291fa87 --- /dev/null +++ b/client/cmd/ssh_test.go @@ -0,0 +1,717 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSHCommand_FlagParsing(t *testing.T) { + tests := []struct { + name string + args []string + expectedHost string + expectedUser string + expectedPort int + expectedCmd string + expectError bool + }{ + { + name: "basic host", + args: []string{"hostname"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "", + }, + { + name: "user@host format", + args: []string{"user@hostname"}, + expectedHost: "hostname", + expectedUser: "user", + expectedPort: 22, + expectedCmd: "", + }, + { + name: "host with command", + args: []string{"hostname", "echo", "hello"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "echo hello", + }, + { + name: "command with flags should be preserved", + args: []string{"hostname", "ls", "-la", "/tmp"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "ls -la /tmp", + }, + { + name: "double dash separator", + args: []string{"hostname", "--", "ls", "-la"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "-- ls -la", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + // Mock command for testing + cmd := sshCmd + cmd.SetArgs(tt.args) + + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + if tt.expectedUser != "" { + assert.Equal(t, tt.expectedUser, username, "username mismatch") + } + assert.Equal(t, tt.expectedPort, port, "port mismatch") + assert.Equal(t, tt.expectedCmd, command, "command mismatch") + }) + } +} + +func TestSSHCommand_FlagConflictPrevention(t *testing.T) { + // Test that SSH flags don't conflict with command flags + tests := []struct { + name string + args []string + expectedCmd string + description string + }{ + { + name: "ls with -la flags", + args: []string{"hostname", "ls", "-la"}, + expectedCmd: "ls -la", + description: "ls flags should be passed to remote command", + }, + { + name: "grep with -r flag", + args: []string{"hostname", "grep", "-r", "pattern", "/path"}, + expectedCmd: "grep -r pattern /path", + description: "grep flags should be passed to remote command", + }, + { + name: "ps with aux flags", + args: []string{"hostname", "ps", "aux"}, + expectedCmd: "ps aux", + description: "ps flags should be passed to remote command", + }, + { + name: "command with double dash", + args: []string{"hostname", "--", "ls", "-la"}, + expectedCmd: "-- ls -la", + description: "double dash should be preserved in command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedCmd, command, tt.description) + }) + } +} + +func TestSSHCommand_NonInteractiveExecution(t *testing.T) { + // Test that commands with arguments should execute the command and exit, + // not drop to an interactive shell + tests := []struct { + name string + args []string + expectedCmd string + shouldExit bool + description string + }{ + { + name: "ls command should execute and exit", + args: []string{"hostname", "ls"}, + expectedCmd: "ls", + shouldExit: true, + description: "ls command should execute and exit, not drop to shell", + }, + { + name: "ls with flags should execute and exit", + args: []string{"hostname", "ls", "-la"}, + expectedCmd: "ls -la", + shouldExit: true, + description: "ls with flags should execute and exit, not drop to shell", + }, + { + name: "pwd command should execute and exit", + args: []string{"hostname", "pwd"}, + expectedCmd: "pwd", + shouldExit: true, + description: "pwd command should execute and exit, not drop to shell", + }, + { + name: "echo command should execute and exit", + args: []string{"hostname", "echo", "hello"}, + expectedCmd: "echo hello", + shouldExit: true, + description: "echo command should execute and exit, not drop to shell", + }, + { + name: "no command should open shell", + args: []string{"hostname"}, + expectedCmd: "", + shouldExit: false, + description: "no command should open interactive shell", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedCmd, command, tt.description) + + // When command is present, it should execute the command and exit + // When command is empty, it should open interactive shell + hasCommand := command != "" + assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior") + }) + } +} + +func TestSSHCommand_FlagHandling(t *testing.T) { + // Test that flags after hostname are not parsed by netbird but passed to SSH command + tests := []struct { + name string + args []string + expectedHost string + expectedCmd string + expectError bool + description string + }{ + { + name: "ls with -la flag should not be parsed by netbird", + args: []string{"debian2", "ls", "-la"}, + expectedHost: "debian2", + expectedCmd: "ls -la", + expectError: false, + description: "ls -la should be passed as SSH command, not parsed as netbird flags", + }, + { + name: "command with netbird-like flags should be passed through", + args: []string{"hostname", "echo", "--help"}, + expectedHost: "hostname", + expectedCmd: "echo --help", + expectError: false, + description: "--help should be passed to echo, not parsed by netbird", + }, + { + name: "command with -p flag should not conflict with SSH port flag", + args: []string{"hostname", "ps", "-p", "1234"}, + expectedHost: "hostname", + expectedCmd: "ps -p 1234", + expectError: false, + description: "ps -p should be passed to ps command, not parsed as port", + }, + { + name: "tar with flags should be passed through", + args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"}, + expectedHost: "hostname", + expectedCmd: "tar -czf backup.tar.gz /home", + expectError: false, + description: "tar flags should be passed to tar command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + assert.Equal(t, tt.expectedCmd, command, tt.description) + }) + } +} + +func TestSSHCommand_RegressionFlagParsing(t *testing.T) { + // Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la" + // should not parse -la as netbird flags but pass them to the SSH command + tests := []struct { + name string + args []string + expectedHost string + expectedCmd string + expectError bool + description string + }{ + { + name: "original issue: ls -la should be preserved", + args: []string{"debian2", "ls", "-la"}, + expectedHost: "debian2", + expectedCmd: "ls -la", + expectError: false, + description: "The original failing case should now work", + }, + { + name: "ls -l should be preserved", + args: []string{"hostname", "ls", "-l"}, + expectedHost: "hostname", + expectedCmd: "ls -l", + expectError: false, + description: "Single letter flags should be preserved", + }, + { + name: "SSH port flag should work", + args: []string{"-p", "2222", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedCmd: "ls -la", + expectError: false, + description: "SSH -p flag should be parsed, command flags preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + assert.Equal(t, tt.expectedCmd, command, tt.description) + + // Check port for the test case with -p flag + if len(tt.args) > 0 && tt.args[0] == "-p" { + assert.Equal(t, 2222, port, "port should be parsed from -p flag") + } + }) + } +} + +func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) { + tests := []struct { + name string + args []string + expectedHost string + expectedLocal []string + expectedRemote []string + expectError bool + description string + }{ + { + name: "local port forwarding -L", + args: []string{"-L", "8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Single -L flag should be parsed correctly", + }, + { + name: "remote port forwarding -R", + args: []string{"-R", "8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80"}, + expectError: false, + description: "Single -R flag should be parsed correctly", + }, + { + name: "multiple local port forwards", + args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"}, + expectedRemote: []string{}, + expectError: false, + description: "Multiple -L flags should be parsed correctly", + }, + { + name: "multiple remote port forwards", + args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"}, + expectError: false, + description: "Multiple -R flags should be parsed correctly", + }, + { + name: "mixed local and remote forwards", + args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{"9090:localhost:443"}, + expectError: false, + description: "Mixed -L and -R flags should be parsed correctly", + }, + { + name: "port forwarding with bind address", + args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"127.0.0.1:8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Port forwarding with bind address should work", + }, + { + name: "port forwarding with command", + args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Port forwarding with command should work", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + localForwards = nil + remoteForwards = nil + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + // Handle nil vs empty slice comparison + if len(tt.expectedLocal) == 0 { + assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty") + } else { + assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards") + } + if len(tt.expectedRemote) == 0 { + assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty") + } else { + assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards") + } + }) + } +} + +func TestParsePortForward(t *testing.T) { + tests := []struct { + name string + spec string + expectedLocal string + expectedRemote string + expectError bool + description string + }{ + { + name: "simple port forward", + spec: "8080:localhost:80", + expectedLocal: "localhost:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "Simple port:host:port format should work", + }, + { + name: "port forward with bind address", + spec: "127.0.0.1:8080:localhost:80", + expectedLocal: "127.0.0.1:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "bind_address:port:host:port format should work", + }, + { + name: "port forward to different host", + spec: "8080:example.com:443", + expectedLocal: "localhost:8080", + expectedRemote: "example.com:443", + expectError: false, + description: "Forwarding to different host should work", + }, + { + name: "port forward with IPv6 (needs bracket support)", + spec: "::1:8080:localhost:80", + expectError: true, + description: "IPv6 without brackets fails as expected (feature to implement)", + }, + { + name: "invalid format - too few parts", + spec: "8080:localhost", + expectError: true, + description: "Invalid format with too few parts should fail", + }, + { + name: "invalid format - too many parts", + spec: "127.0.0.1:8080:localhost:80:extra", + expectError: true, + description: "Invalid format with too many parts should fail", + }, + { + name: "empty spec", + spec: "", + expectError: true, + description: "Empty spec should fail", + }, + { + name: "unix socket local forward", + spec: "8080:/tmp/socket", + expectedLocal: "localhost:8080", + expectedRemote: "/tmp/socket", + expectError: false, + description: "Unix socket forwarding should work", + }, + { + name: "unix socket with bind address", + spec: "127.0.0.1:8080:/tmp/socket", + expectedLocal: "127.0.0.1:8080", + expectedRemote: "/tmp/socket", + expectError: false, + description: "Unix socket with bind address should work", + }, + { + name: "wildcard bind all interfaces", + spec: "*:8080:localhost:80", + expectedLocal: "0.0.0.0:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "Wildcard * should bind to all interfaces (0.0.0.0)", + }, + { + name: "wildcard for port only", + spec: "8080:*:80", + expectedLocal: "localhost:8080", + expectedRemote: "*:80", + expectError: false, + description: "Wildcard in remote host should be preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec) + + if tt.expectError { + assert.Error(t, err, tt.description) + return + } + + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address") + assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address") + }) + } +} + +func TestSSHCommand_IntegrationPortForwarding(t *testing.T) { + // Integration test for port forwarding with the actual SSH command implementation + tests := []struct { + name string + args []string + expectedHost string + expectedLocal []string + expectedRemote []string + expectedCmd string + description string + }{ + { + name: "local forward with command", + args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectedCmd: "echo test", + description: "Local forwarding should work with commands", + }, + { + name: "remote forward with command", + args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80"}, + expectedCmd: "ls -la", + description: "Remote forwarding should work with commands", + }, + { + name: "multiple forwards with user and command", + args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{"9090:localhost:443"}, + expectedCmd: "ps aux", + description: "Complex case with multiple forwards, user, and command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + localForwards = nil + remoteForwards = nil + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedHost, host, "host mismatch") + // Handle nil vs empty slice comparison + if len(tt.expectedLocal) == 0 { + assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty") + } else { + assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards") + } + if len(tt.expectedRemote) == 0 { + assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty") + } else { + assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards") + } + assert.Equal(t, tt.expectedCmd, command, tt.description+" - command") + }) + } +} + +func TestSSHCommand_ParameterIsolation(t *testing.T) { + tests := []struct { + name string + args []string + expectedCmd string + }{ + { + name: "cmd flag passed as command", + args: []string{"hostname", "--cmd", "echo test"}, + expectedCmd: "--cmd echo test", + }, + { + name: "uid flag passed as command", + args: []string{"hostname", "--uid", "1000"}, + expectedCmd: "--uid 1000", + }, + { + name: "shell flag passed as command", + args: []string{"hostname", "--shell", "/bin/bash"}, + expectedCmd: "--shell /bin/bash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host = "" + username = "" + port = 22 + command = "" + + err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args) + require.NoError(t, err) + + assert.Equal(t, "hostname", host) + assert.Equal(t, tt.expectedCmd, command) + }) + } +} + +func TestSSHCommand_InvalidFlagRejection(t *testing.T) { + // Test that invalid flags are properly rejected and not misinterpreted as hostnames + tests := []struct { + name string + args []string + description string + }{ + { + name: "invalid long flag before hostname", + args: []string{"--invalid-flag", "hostname"}, + description: "Invalid flag should return parse error, not treat flag as hostname", + }, + { + name: "invalid short flag before hostname", + args: []string{"-x", "hostname"}, + description: "Invalid short flag should return parse error", + }, + { + name: "invalid flag with value before hostname", + args: []string{"--invalid-option=value", "hostname"}, + description: "Invalid flag with value should return parse error", + }, + { + name: "typo in known flag", + args: []string{"--por", "2222", "hostname"}, + description: "Typo in flag name should return parse error (not silently ignored)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args) + + // Should return an error for invalid flags + assert.Error(t, err, tt.description) + + // Should not have set host to the invalid flag + assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname") + }) + } +} diff --git a/client/cmd/status.go b/client/cmd/status.go index 6e57ceb89..06460a6a7 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { case yamlFlag: statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) default: - statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false) + statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false) } if err != nil { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 78bb0476b..e7b0279e8 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -117,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 80175f7be..140ba2cb2 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro if cmd.Flag(serverSSHAllowedFlag).Changed { req.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + req.EnableSSHRoot = &enableSSHRoot + } + if cmd.Flag(enableSSHSFTPFlag).Changed { + req.EnableSSHSFTP = &enableSSHSFTP + } + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + if cmd.Flag(disableSSHAuthFlag).Changed { + req.DisableSSHAuth = &disableSSHAuth + } + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + req.SshJWTCacheTTL = &sshJWTCacheTTL32 + } if cmd.Flag(interfaceNameFlag).Changed { if err := parseInterfaceName(interfaceName); err != nil { log.Errorf("parse interface name: %v", err) @@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + ic.EnableSSHRoot = &enableSSHRoot + } + + if cmd.Flag(enableSSHSFTPFlag).Changed { + ic.EnableSSHSFTP = &enableSSHSFTP + } + + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + + if cmd.Flag(disableSSHAuthFlag).Changed { + ic.DisableSSHAuth = &disableSSHAuth + } + + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + ic.SSHJWTCacheTTL = &sshJWTCacheTTL + } + if cmd.Flag(interfaceNameFlag).Changed { if err := parseInterfaceName(interfaceName); err != nil { return nil, err @@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + loginRequest.EnableSSHRoot = &enableSSHRoot + } + + if cmd.Flag(enableSSHSFTPFlag).Changed { + loginRequest.EnableSSHSFTP = &enableSSHSFTP + } + + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + + if cmd.Flag(disableSSHAuthFlag).Changed { + loginRequest.DisableSSHAuth = &disableSSHAuth + } + + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32 + } + if cmd.Flag(disableAutoConnectFlag).Changed { loginRequest.DisableAutoConnect = &autoConnectDisabled } diff --git a/client/embed/embed.go b/client/embed/embed.go index e918235ed..3090ca6a2 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -18,12 +18,16 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + sshcommon "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" ) -var ErrClientAlreadyStarted = errors.New("client already started") -var ErrClientNotStarted = errors.New("client not started") -var ErrConfigNotInitialized = errors.New("config not initialized") +var ( + ErrClientAlreadyStarted = errors.New("client already started") + ErrClientNotStarted = errors.New("client not started") + ErrEngineNotStarted = errors.New("engine not started") + ErrConfigNotInitialized = errors.New("config not initialized") +) // Client manages a netbird embedded client instance. type Client struct { @@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) { // Dial dials a network address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { - c.mu.Lock() - connect := c.connect - if connect == nil { - c.mu.Unlock() - return nil, ErrClientNotStarted - } - c.mu.Unlock() - - engine := connect.Engine() - if engine == nil { - return nil, errors.New("engine not started") + engine, err := c.getEngine() + if err != nil { + return nil, err } nsnet, err := engine.GetNet() @@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e return nsnet.DialContext(ctx, network, address) } +// DialContext dials a network address in the netbird network with context +func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return c.Dial(ctx, network, address) +} + // ListenTCP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenTCP(address string) (net.Listener, error) { @@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client { } } -func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { +// VerifySSHHostKey verifies an SSH host key against stored peer keys. +// Returns nil if the key matches, ErrPeerNotFound if peer is not in network, +// ErrNoStoredKey if peer has no stored key, or an error for verification failures. +func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { + engine, err := c.getEngine() + if err != nil { + return err + } + + storedKey, found := engine.GetPeerSSHKey(peerAddress) + if !found { + return sshcommon.ErrPeerNotFound + } + + return sshcommon.VerifyHostKey(storedKey, key, peerAddress) +} + +// getEngine safely retrieves the engine from the client with proper locking. +// Returns ErrClientNotStarted if the client is not started. +// Returns ErrEngineNotStarted if the engine is not available. +func (c *Client) getEngine() (*internal.Engine, error) { c.mu.Lock() connect := c.connect - if connect == nil { - c.mu.Unlock() - return nil, netip.Addr{}, errors.New("client not started") - } c.mu.Unlock() + if connect == nil { + return nil, ErrClientNotStarted + } + engine := connect.Engine() if engine == nil { - return nil, netip.Addr{}, errors.New("engine not started") + return nil, ErrEngineNotStarted + } + + return engine, nil +} + +func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { + engine, err := c.getEngine() + if err != nil { + return nil, netip.Addr{}, err } addr, err := engine.Address() diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 990630ee4..4e22bde3f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -35,6 +35,12 @@ const ( ipTCPHeaderMinSize = 40 ) +// serviceKey represents a protocol/port combination for netstack service registry +type serviceKey struct { + protocol gopacket.LayerType + port uint16 +} + const ( // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. EnvDisableConntrack = "NB_DISABLE_CONNTRACK" @@ -59,12 +65,6 @@ const ( var errNatNotSupported = errors.New("nat not supported with userspace firewall") -// serviceKey represents a protocol/port combination for netstack service registry -type serviceKey struct { - protocol gopacket.LayerType - port uint16 -} - // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index c56a078fc..120a9f418 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst return buf.Bytes() } + +func TestShouldForward(t *testing.T) { + // Set up test addresses + wgIP := netip.MustParseAddr("100.10.0.1") + otherIP := netip.MustParseAddr("100.10.0.2") + + // Create test manager with mock interface + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + // Set the mock to return our test WG IP + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)} + } + + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Helper to create decoder with TCP packet + createTCPDecoder := func(dstPort uint16) *decoder { + ipv4 := &layers.IPv4{ + Version: 4, + Protocol: layers.IPProtocolTCP, + SrcIP: net.ParseIP("192.168.1.100"), + DstIP: wgIP.AsSlice(), + } + tcp := &layers.TCP{ + SrcPort: 54321, + DstPort: layers.TCPPort(dstPort), + } + + err := tcp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")) + require.NoError(t, err) + + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + + err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) + require.NoError(t, err) + + return d + } + + tests := []struct { + name string + localForwarding bool + netstack bool + dstIP netip.Addr + serviceRegistered bool + servicePort uint16 + expected bool + description string + }{ + { + name: "no local forwarding", + localForwarding: false, + netstack: true, + dstIP: wgIP, + expected: false, + description: "should never forward when local forwarding disabled", + }, + { + name: "traffic to other local interface", + localForwarding: true, + netstack: false, + dstIP: otherIP, + expected: true, + description: "should forward traffic to our other local interfaces (not NetBird IP)", + }, + { + name: "traffic to NetBird IP, no netstack", + localForwarding: true, + netstack: false, + dstIP: wgIP, + expected: false, + description: "should send to netstack listeners (final return false path)", + }, + { + name: "traffic to our IP, netstack mode, no service", + localForwarding: true, + netstack: true, + dstIP: wgIP, + expected: true, + description: "should forward when in netstack mode with no matching service", + }, + { + name: "traffic to our IP, netstack mode, with service", + localForwarding: true, + netstack: true, + dstIP: wgIP, + serviceRegistered: true, + servicePort: 22, + expected: false, + description: "should send to netstack listeners when service is registered", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Configure manager + manager.localForwarding = tt.localForwarding + manager.netstack = tt.netstack + + // Register service if needed + if tt.serviceRegistered { + manager.RegisterNetstackService(nftypes.TCP, tt.servicePort) + defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort) + } + + // Create decoder for the test + decoder := createTCPDecoder(tt.servicePort) + if !tt.serviceRegistered { + decoder = createTCPDecoder(8080) // Use non-registered port + } + + // Test the method + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) + } +} diff --git a/client/firewall/uspfilter/nat_stateful_test.go b/client/firewall/uspfilter/nat_stateful_test.go new file mode 100644 index 000000000..21c6da06e --- /dev/null +++ b/client/firewall/uspfilter/nat_stateful_test.go @@ -0,0 +1,85 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +// TestPortDNATBasic tests basic port DNAT functionality +func TestPortDNATBasic(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Define peer IPs + peerA := netip.MustParseAddr("100.10.0.50") + peerB := netip.MustParseAddr("100.10.0.51") + + // Add SSH port redirection rule for peer B (the target) + err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + + // Scenario: Peer A connects to Peer B on port 22 (should get NAT) + packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) + d := parsePacket(t, packetAtoB) + translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB) + require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)") + + // Verify port was translated to 22022 + d = parsePacket(t, packetAtoB) + require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022") + + // Scenario: Return traffic from Peer B to Peer A should NOT be translated + // (prevents double NAT - original port stored in conntrack) + returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321) + d2 := parsePacket(t, returnPacket) + translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA) + require.False(t, translatedReturn, "Return traffic from same IP should not be translated") +} + +// TestPortDNATMultipleRules tests multiple port DNAT rules +func TestPortDNATMultipleRules(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Define peer IPs + peerA := netip.MustParseAddr("100.10.0.50") + peerB := netip.MustParseAddr("100.10.0.51") + + // Add SSH port redirection rules for both peers + err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + + // Test traffic to peer B gets translated + packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) + d1 := parsePacket(t, packetToB) + translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB) + require.True(t, translatedToB, "Traffic to peer B should be translated") + d1 = parsePacket(t, packetToB) + require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022") + + // Test traffic to peer A gets translated + packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22) + d2 := parsePacket(t, packetToA) + translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA) + require.True(t, translatedToA, "Traffic to peer A should be translated") + d2 = parsePacket(t, packetToA) + require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022") +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 965decc73..dd6f9479a 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -17,7 +17,6 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" - "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { rules := networkMap.FirewallRules - enableSSH := networkMap.PeerConfig != nil && - networkMap.PeerConfig.SshConfig != nil && - networkMap.PeerConfig.SshConfig.SshEnabled - - // If SSH enabled, add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). - if enableSSH { - rules = append(rules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: strconv.Itoa(ssh.DefaultSSHPort), - }) - } - // if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag // we have old version of management without rules handling, we should allow all traffic if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty { diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 638245bf7..4bc0fd800 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) { }) } } - -func TestDefaultManagerEnableSSHRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - PeerConfig: &mgmProto.PeerConfig{ - SshConfig: &mgmProto.SSHConfig{ - SshEnabled: true, - }, - }, - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ifaceMock := mocks.NewMockIFaceMapper(ctrl) - ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() - ifaceMock.EXPECT().SetFilter(gomock.Any()) - network := netip.MustParsePrefix("172.0.0.1/32") - - ifaceMock.EXPECT().Name().Return("lo").AnyTimes() - ifaceMock.EXPECT().Address().Return(wgaddr.Address{ - IP: network.Addr(), - Network: network, - }).AnyTimes() - ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) - require.NoError(t, err) - defer func() { - err = fw.Close(nil) - require.NoError(t, err) - }() - - acl := NewDefaultManager(fw) - - acl.ApplyFiltering(networkMap, false) - - expectedRules := 3 - if fw.IsStateful() { - expectedRules = 3 // 2 inbound rules + SSH rule - } - assert.Equal(t, expectedRules, len(acl.peerRulesPairs)) -} diff --git a/client/internal/connect.go b/client/internal/connect.go index bb7c2b38b..6ad5f264b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -416,20 +416,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf nm = *config.NetworkMonitor } engineConf := &EngineConfig{ - WgIfaceName: config.WgIface, - WgAddr: peerConfig.Address, - IFaceBlackList: config.IFaceBlackList, - DisableIPv6Discovery: config.DisableIPv6Discovery, - WgPrivateKey: key, - WgPort: config.WgPort, - NetworkMonitor: nm, - SSHKey: []byte(config.SSHKey), - NATExternalIPs: config.NATExternalIPs, - CustomDNSAddress: config.CustomDNSAddress, - RosenpassEnabled: config.RosenpassEnabled, - RosenpassPermissive: config.RosenpassPermissive, - ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), - DNSRouteInterval: config.DNSRouteInterval, + WgIfaceName: config.WgIface, + WgAddr: peerConfig.Address, + IFaceBlackList: config.IFaceBlackList, + DisableIPv6Discovery: config.DisableIPv6Discovery, + WgPrivateKey: key, + WgPort: config.WgPort, + NetworkMonitor: nm, + SSHKey: []byte(config.SSHKey), + NATExternalIPs: config.NATExternalIPs, + CustomDNSAddress: config.CustomDNSAddress, + RosenpassEnabled: config.RosenpassEnabled, + RosenpassPermissive: config.RosenpassPermissive, + ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), + EnableSSHRoot: config.EnableSSHRoot, + EnableSSHSFTP: config.EnableSSHSFTP, + EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding, + EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding, + DisableSSHAuth: config.DisableSSHAuth, + DNSRouteInterval: config.DNSRouteInterval, DisableClientRoutes: config.DisableClientRoutes, DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound, @@ -515,6 +520,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) if err != nil { diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index fbec29ce3..58977b884 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -453,6 +453,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) if g.internalConfig.ServerSSHAllowed != nil { configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed)) } + if g.internalConfig.EnableSSHRoot != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot)) + } + if g.internalConfig.EnableSSHSFTP != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP)) + } + if g.internalConfig.EnableSSHLocalPortForwarding != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding)) + } + if g.internalConfig.EnableSSHRemotePortForwarding != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) + } configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) diff --git a/client/internal/engine.go b/client/internal/engine.go index ebc05c453..1deb3d3cf 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -9,7 +9,6 @@ import ( "net/netip" "net/url" "os" - "reflect" "runtime" "slices" "sort" @@ -30,7 +29,6 @@ import ( firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" - nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" @@ -51,10 +49,10 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" cProto "github.com/netbirdio/netbird/client/proto" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/shared/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -115,7 +113,12 @@ type EngineConfig struct { RosenpassEnabled bool RosenpassPermissive bool - ServerSSHAllowed bool + ServerSSHAllowed bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool DNSRouteInterval time.Duration @@ -148,8 +151,6 @@ type Engine struct { // syncMsgMux is used to guarantee sequential Management Service message processing syncMsgMux *sync.Mutex - // sshMux protects sshServer field access - sshMux sync.Mutex config *EngineConfig mobileDep MobileDependency @@ -175,8 +176,7 @@ type Engine struct { networkMonitor *networkmonitor.NetworkMonitor - sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) - sshServer nbssh.Server + sshServer sshServer statusRecorder *peer.Status @@ -246,7 +246,6 @@ func NewEngine( STUNs: []*stun.URI{}, TURNs: []*stun.URI{}, networkSerial: 0, - sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), @@ -268,6 +267,7 @@ func NewEngine( path = mobileDep.StateFilePath } engine.stateManager = statemanager.New(path) + engine.stateManager.RegisterState(&sshconfig.ShutdownState{}) log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) return engine @@ -292,6 +292,12 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + if err := e.stopSSHServer(); err != nil { + log.Warnf("failed to stop SSH server: %v", err) + } + + e.cleanupSSHConfig() + // stop/restore DNS first so dbus and friends don't complain because of a missing interface e.stopDNSServer() @@ -703,16 +709,10 @@ func (e *Engine) removeAllPeers() error { return nil } -// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server +// removePeer closes an existing peer connection and removes a peer func (e *Engine) removePeer(peerKey string) error { log.Debugf("removing peer from engine %s", peerKey) - e.sshMux.Lock() - if !isNil(e.sshServer) { - e.sshServer.RemoveAuthorizedKey(peerKey) - } - e.sshMux.Unlock() - e.connMgr.RemovePeerConn(peerKey) err := e.statusRecorder.RemovePeer(peerKey) @@ -884,6 +884,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) if err := e.mgmClient.SyncMeta(info); err != nil { @@ -893,74 +898,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { return nil } -func isNil(server nbssh.Server) bool { - return server == nil || reflect.ValueOf(server).IsNil() -} - -func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { - if e.config.BlockInbound { - log.Infof("SSH server is disabled because inbound connections are blocked") - return nil - } - - if !e.config.ServerSSHAllowed { - log.Info("SSH server is not enabled") - return nil - } - - if sshConf.GetSshEnabled() { - if runtime.GOOS == "windows" { - log.Warnf("running SSH server on %s is not supported", runtime.GOOS) - return nil - } - e.sshMux.Lock() - // start SSH server if it wasn't running - if isNil(e.sshServer) { - listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) - if nbnetstack.IsEnabled() { - listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort) - } - // nil sshServer means it has not yet been started - server, err := e.sshServerFunc(e.config.SSHKey, listenAddr) - if err != nil { - e.sshMux.Unlock() - return fmt.Errorf("create ssh server: %w", err) - } - - e.sshServer = server - e.sshMux.Unlock() - - go func() { - // blocking - err = server.Start() - if err != nil { - // will throw error when we stop it even if it is a graceful stop - log.Debugf("stopped SSH server with error %v", err) - } - e.sshMux.Lock() - e.sshServer = nil - e.sshMux.Unlock() - log.Infof("stopped SSH server") - }() - } else { - e.sshMux.Unlock() - log.Debugf("SSH server is already running") - } - } else { - e.sshMux.Lock() - if !isNil(e.sshServer) { - // Disable SSH server request, so stop it if it was running - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed to stop SSH server %v", err) - } - e.sshServer = nil - } - e.sshMux.Unlock() - } - return nil -} - func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { if e.wgInterface == nil { return errors.New("wireguard interface is not initialized") @@ -973,8 +910,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { } if conf.GetSshConfig() != nil { - err := e.updateSSH(conf.GetSshConfig()) - if err != nil { + if err := e.updateSSH(conf.GetSshConfig()); err != nil { log.Warnf("failed handling SSH server setup: %v", err) } } @@ -1012,6 +948,11 @@ func (e *Engine) receiveManagementEvents() { e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) @@ -1170,19 +1111,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.statusRecorder.FinishPeerListModifications() - // update SSHServer by adding remote peer SSH keys - e.sshMux.Lock() - if !isNil(e.sshServer) { - for _, config := range networkMap.GetRemotePeers() { - if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { - err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey())) - if err != nil { - log.Warnf("failed adding authorized key to SSH DefaultServer %v", err) - } - } - } + e.updatePeerSSHHostKeys(networkMap.GetRemotePeers()) + + if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil { + log.Warnf("failed to update SSH client config: %v", err) } - e.sshMux.Unlock() } // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store @@ -1544,15 +1477,6 @@ func (e *Engine) close() { e.statusRecorder.SetWgIface(nil) } - e.sshMux.Lock() - if !isNil(e.sshServer) { - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed stopping the SSH server: %v", err) - } - } - e.sshMux.Unlock() - if e.firewall != nil { err := e.firewall.Close(e.stateManager) if err != nil { @@ -1583,6 +1507,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) netMap, err := e.mgmClient.GetNetworkMap(info) diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go new file mode 100644 index 000000000..861b3d6d2 --- /dev/null +++ b/client/internal/engine_ssh.go @@ -0,0 +1,355 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net/netip" + "strings" + + log "github.com/sirupsen/logrus" + + firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" + sshserver "github.com/netbirdio/netbird/client/ssh/server" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +type sshServer interface { + Start(ctx context.Context, addr netip.AddrPort) error + Stop() error + GetStatus() (bool, []sshserver.SessionInfo) +} + +func (e *Engine) setupSSHPortRedirection() error { + if e.firewall == nil || e.wgInterface == nil { + return nil + } + + localAddr := e.wgInterface.Address().IP + if !localAddr.IsValid() { + return errors.New("invalid local NetBird address") + } + + if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil { + return fmt.Errorf("add SSH port redirection: %w", err) + } + log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr) + + return nil +} + +func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { + if e.config.BlockInbound { + log.Info("SSH server is disabled because inbound connections are blocked") + return e.stopSSHServer() + } + + if !e.config.ServerSSHAllowed { + log.Info("SSH server is disabled in config") + return e.stopSSHServer() + } + + if !sshConf.GetSshEnabled() { + if e.config.ServerSSHAllowed { + log.Info("SSH server is locally allowed but disabled by management server") + } + return e.stopSSHServer() + } + + if e.sshServer != nil { + log.Debug("SSH server is already running") + return nil + } + + if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth { + log.Info("starting SSH server without JWT authentication (authentication disabled by config)") + return e.startSSHServer(nil) + } + + if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil { + jwtConfig := &sshserver.JWTConfig{ + Issuer: protoJWT.GetIssuer(), + Audience: protoJWT.GetAudience(), + KeysLocation: protoJWT.GetKeysLocation(), + MaxTokenAge: protoJWT.GetMaxTokenAge(), + } + + return e.startSSHServer(jwtConfig) + } + + return errors.New("SSH server requires valid JWT configuration") +} + +// updateSSHClientConfig updates the SSH client configuration with peer information +func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error { + peerInfo := e.extractPeerSSHInfo(remotePeers) + if len(peerInfo) == 0 { + log.Debug("no SSH-enabled peers found, skipping SSH config update") + return nil + } + + configMgr := sshconfig.New() + if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil { + log.Warnf("failed to update SSH client config: %v", err) + return nil // Don't fail engine startup on SSH config issues + } + + log.Debugf("updated SSH client config with %d peers", len(peerInfo)) + + if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{ + SSHConfigDir: configMgr.GetSSHConfigDir(), + SSHConfigFile: configMgr.GetSSHConfigFile(), + }); err != nil { + log.Warnf("failed to update SSH config state: %v", err) + } + + return nil +} + +// extractPeerSSHInfo extracts SSH information from peer configurations +func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo { + var peerInfo []sshconfig.PeerSSHInfo + + for _, peerConfig := range remotePeers { + if peerConfig.GetSshConfig() == nil { + continue + } + + sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey() + if len(sshPubKeyBytes) == 0 { + continue + } + + peerIP := e.extractPeerIP(peerConfig) + hostname := e.extractHostname(peerConfig) + + peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ + Hostname: hostname, + IP: peerIP, + FQDN: peerConfig.GetFqdn(), + }) + } + + return peerInfo +} + +// extractPeerIP extracts IP address from peer's allowed IPs +func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string { + if len(peerConfig.GetAllowedIps()) == 0 { + return "" + } + + if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil { + return prefix.Addr().String() + } + return "" +} + +// extractHostname extracts short hostname from FQDN +func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string { + fqdn := peerConfig.GetFqdn() + if fqdn == "" { + return "" + } + + parts := strings.Split(fqdn, ".") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + return "" +} + +// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access +func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) { + for _, peerConfig := range remotePeers { + if peerConfig.GetSshConfig() == nil { + continue + } + + sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey() + if len(sshPubKeyBytes) == 0 { + continue + } + + if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil { + log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err) + } + } + + log.Debugf("updated peer SSH host keys for daemon API access") +} + +// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN +func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { + e.syncMsgMux.Lock() + statusRecorder := e.statusRecorder + e.syncMsgMux.Unlock() + + if statusRecorder == nil { + return nil, false + } + + fullStatus := statusRecorder.GetFullStatus() + for _, peerState := range fullStatus.Peers { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + if len(peerState.SSHHostKey) > 0 { + return peerState.SSHHostKey, true + } + return nil, false + } + } + + return nil, false +} + +// cleanupSSHConfig removes NetBird SSH client configuration on shutdown +func (e *Engine) cleanupSSHConfig() { + configMgr := sshconfig.New() + + if err := configMgr.RemoveSSHClientConfig(); err != nil { + log.Warnf("failed to remove SSH client config: %v", err) + } else { + log.Debugf("SSH client config cleanup completed") + } +} + +// startSSHServer initializes and starts the SSH server with proper configuration. +func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error { + if e.wgInterface == nil { + return errors.New("wg interface not initialized") + } + + serverConfig := &sshserver.Config{ + HostKeyPEM: e.config.SSHKey, + JWT: jwtConfig, + } + server := sshserver.New(serverConfig) + + wgAddr := e.wgInterface.Address() + server.SetNetworkValidation(wgAddr) + + netbirdIP := wgAddr.IP + listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort) + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + server.SetNetstackNet(netstackNet) + } + + e.configureSSHServer(server) + + if err := server.Start(e.ctx, listenAddr); err != nil { + return fmt.Errorf("start SSH server: %w", err) + } + + e.sshServer = server + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort) + log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort) + } + } + + if err := e.setupSSHPortRedirection(); err != nil { + log.Warnf("failed to setup SSH port redirection: %v", err) + } + + return nil +} + +// configureSSHServer applies SSH configuration options to the server. +func (e *Engine) configureSSHServer(server *sshserver.Server) { + if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot { + server.SetAllowRootLogin(true) + log.Info("SSH root login enabled") + } else { + server.SetAllowRootLogin(false) + log.Info("SSH root login disabled (default)") + } + + if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP { + server.SetAllowSFTP(true) + log.Info("SSH SFTP subsystem enabled") + } else { + server.SetAllowSFTP(false) + log.Info("SSH SFTP subsystem disabled (default)") + } + + if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding { + server.SetAllowLocalPortForwarding(true) + log.Info("SSH local port forwarding enabled") + } else { + server.SetAllowLocalPortForwarding(false) + log.Info("SSH local port forwarding disabled (default)") + } + + if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding { + server.SetAllowRemotePortForwarding(true) + log.Info("SSH remote port forwarding enabled") + } else { + server.SetAllowRemotePortForwarding(false) + log.Info("SSH remote port forwarding disabled (default)") + } +} + +func (e *Engine) cleanupSSHPortRedirection() error { + if e.firewall == nil || e.wgInterface == nil { + return nil + } + + localAddr := e.wgInterface.Address().IP + if !localAddr.IsValid() { + return errors.New("invalid local NetBird address") + } + + if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil { + return fmt.Errorf("remove SSH port redirection: %w", err) + } + log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr) + + return nil +} + +func (e *Engine) stopSSHServer() error { + if e.sshServer == nil { + return nil + } + + if err := e.cleanupSSHPortRedirection(); err != nil { + log.Warnf("failed to cleanup SSH port redirection: %v", err) + } + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort) + log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort) + } + } + + log.Info("stopping SSH server") + err := e.sshServer.Stop() + e.sshServer = nil + if err != nil { + return fmt.Errorf("stop: %w", err) + } + return nil +} + +// GetSSHServerStatus returns the SSH server status and active sessions +func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) { + e.syncMsgMux.Lock() + sshServer := e.sshServer + e.syncMsgMux.Unlock() + + if sshServer == nil { + return false, nil + } + + return sshServer.GetStatus() +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d15a07f9d..3b7ff0eba 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -14,7 +14,6 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - "github.com/netbirdio/netbird/client/internal/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,7 +24,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -46,7 +48,7 @@ import ( icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/routemanager" - "github.com/netbirdio/netbird/client/ssh" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" @@ -214,11 +216,13 @@ func TestMain(m *testing.M) { } func TestEngine_SSH(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("skipping TestEngine_SSH") + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return } - key, err := wgtypes.GeneratePrivateKey() + sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) if err != nil { t.Fatal(err) return @@ -240,6 +244,7 @@ func TestEngine_SSH(t *testing.T) { WgPort: 33100, ServerSSHAllowed: true, MTU: iface.DefaultMTU, + SSHKey: sshKey, }, MobileDependency{}, peer.NewRecorder("https://mgm"), @@ -250,35 +255,8 @@ func TestEngine_SSH(t *testing.T) { UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } - var sshKeysAdded []string - var sshPeersRemoved []string - - sshCtx, cancel := context.WithCancel(context.Background()) - - engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) { - return &ssh.MockServer{ - Ctx: sshCtx, - StopFunc: func() error { - cancel() - return nil - }, - StartFunc: func() error { - <-ctx.Done() - return ctx.Err() - }, - AddAuthorizedKeyFunc: func(peer, newKey string) error { - sshKeysAdded = append(sshKeysAdded, newKey) - return nil - }, - RemoveAuthorizedKeyFunc: func(peer string) { - sshPeersRemoved = append(sshPeersRemoved, peer) - }, - }, nil - } err = engine.Start(nil, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer func() { err := engine.Stop() @@ -304,9 +282,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) @@ -314,19 +290,24 @@ func TestEngine_SSH(t *testing.T) { networkMap = &mgmtProto.NetworkMap{ Serial: 7, PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24", - SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}}, + SshConfig: &mgmtProto.SSHConfig{ + SshEnabled: true, + JwtConfig: &mgmtProto.JWTConfig{ + Issuer: "test-issuer", + Audience: "test-audience", + KeysLocation: "test-keys", + MaxTokenAge: 3600, + }, + }}, RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH}, RemotePeersIsEmpty: false, } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) - assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ") // now remove peer networkMap = &mgmtProto.NetworkMap{ @@ -336,13 +317,10 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) - assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=") // now disable SSH server networkMap = &mgmtProto.NetworkMap{ @@ -354,12 +332,70 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) +} +func TestEngine_SSHUpdateLogic(t *testing.T) { + // Test that SSH server start/stop logic works based on config + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: false, // Start with SSH disabled + }, + syncMsgMux: &sync.Mutex{}, + } + + // Test SSH disabled config + sshConfig := &mgmtProto.SSHConfig{SshEnabled: false} + err := engine.updateSSH(sshConfig) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + + // Test inbound blocked + engine.config.BlockInbound = true + err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + engine.config.BlockInbound = false + + // Test with server SSH not allowed + err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) +} + +func TestEngine_SSHServerConsistency(t *testing.T) { + + t.Run("server set only on successful creation", func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: true, + SSHKey: []byte("test-key"), + }, + syncMsgMux: &sync.Mutex{}, + } + + engine.wgInterface = nil + + err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + + assert.Error(t, err) + assert.Nil(t, engine.sshServer) + }) + + t.Run("cleanup handles nil gracefully", func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: false, + }, + syncMsgMux: &sync.Mutex{}, + } + + err := engine.stopSSHServer() + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + }) } func TestEngine_UpdateNetworkMap(t *testing.T) { @@ -1589,7 +1625,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } diff --git a/client/internal/login.go b/client/internal/login.go index 257e3c3ac..f528783ef 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) return serverKey, loginResp, err @@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm. config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) if err != nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 68afe986a..426c31e1a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { } }() - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { + if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { return false } diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 32a458d00..7f500c410 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -2,6 +2,7 @@ package peer import ( "os" + "runtime" "strings" ) @@ -10,5 +11,8 @@ const ( ) func isForceRelayed() bool { + if runtime.GOOS == "js" { + return true + } return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 239cce7e0..76f4f523c 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -21,9 +21,9 @@ import ( "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" relayClient "github.com/netbirdio/netbird/shared/relay/client" - "github.com/netbirdio/netbird/route" ) const eventQueueSize = 10 @@ -67,6 +67,7 @@ type State struct { BytesRx int64 Latency time.Duration RosenpassEnabled bool + SSHHostKey []byte routes map[string]struct{} } @@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error { return nil } +// UpdatePeerSSHHostKey updates peer's SSH host key +func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peerPubKey] + if !ok { + return errors.New("peer doesn't exist") + } + + peerState.SSHHostKey = sshHostKey + d.peers[peerPubKey] = peerState + + return nil +} + // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index f03822089..8f467a214 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -44,24 +44,30 @@ var DefaultInterfaceBlacklist = []string{ // ConfigInput carries configuration changes to the client type ConfigInput struct { - ManagementURL string - AdminURL string - ConfigPath string - StateFilePath string - PreSharedKey *string - ServerSSHAllowed *bool - NATExternalIPs []string - CustomDNSAddress []byte - RosenpassEnabled *bool - RosenpassPermissive *bool - InterfaceName *string - WireguardPort *int - NetworkMonitor *bool - DisableAutoConnect *bool - ExtraIFaceBlackList []string - DNSRouteInterval *time.Duration - ClientCertPath string - ClientCertKeyPath string + ManagementURL string + AdminURL string + ConfigPath string + StateFilePath string + PreSharedKey *string + ServerSSHAllowed *bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool + SSHJWTCacheTTL *int + NATExternalIPs []string + CustomDNSAddress []byte + RosenpassEnabled *bool + RosenpassPermissive *bool + InterfaceName *string + WireguardPort *int + NetworkMonitor *bool + DisableAutoConnect *bool + ExtraIFaceBlackList []string + DNSRouteInterval *time.Duration + ClientCertPath string + ClientCertKeyPath string DisableClientRoutes *bool DisableServerRoutes *bool @@ -82,18 +88,24 @@ type ConfigInput struct { // Config Configuration type type Config struct { // Wireguard private key of local peer - PrivateKey string - PreSharedKey string - ManagementURL *url.URL - AdminURL *url.URL - WgIface string - WgPort int - NetworkMonitor *bool - IFaceBlackList []string - DisableIPv6Discovery bool - RosenpassEnabled bool - RosenpassPermissive bool - ServerSSHAllowed *bool + PrivateKey string + PreSharedKey string + ManagementURL *url.URL + AdminURL *url.URL + WgIface string + WgPort int + NetworkMonitor *bool + IFaceBlackList []string + DisableIPv6Discovery bool + RosenpassEnabled bool + RosenpassPermissive bool + ServerSSHAllowed *bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool + SSHJWTCacheTTL *int DisableClientRoutes bool DisableServerRoutes bool @@ -376,6 +388,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } + if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot { + if *input.EnableSSHRoot { + log.Infof("enabling SSH root login") + } else { + log.Infof("disabling SSH root login") + } + config.EnableSSHRoot = input.EnableSSHRoot + updated = true + } + + if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP { + if *input.EnableSSHSFTP { + log.Infof("enabling SSH SFTP subsystem") + } else { + log.Infof("disabling SSH SFTP subsystem") + } + config.EnableSSHSFTP = input.EnableSSHSFTP + updated = true + } + + if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding { + if *input.EnableSSHLocalPortForwarding { + log.Infof("enabling SSH local port forwarding") + } else { + log.Infof("disabling SSH local port forwarding") + } + config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding + updated = true + } + + if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding { + if *input.EnableSSHRemotePortForwarding { + log.Infof("enabling SSH remote port forwarding") + } else { + log.Infof("disabling SSH remote port forwarding") + } + config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding + updated = true + } + + if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth { + if *input.DisableSSHAuth { + log.Infof("disabling SSH authentication") + } else { + log.Infof("enabling SSH authentication") + } + config.DisableSSHAuth = input.DisableSSHAuth + updated = true + } + + if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL { + log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL) + config.SSHJWTCacheTTL = input.SSHJWTCacheTTL + updated = true + } + if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval { log.Infof("updating DNS route interval to %s (old value %s)", input.DNSRouteInterval.String(), config.DNSRouteInterval.String()) diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 90bde7707..ab13cf389 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) { func TestWireguardPortDefaultVsExplicit(t *testing.T) { tests := []struct { - name string - wireguardPort *int - expectedPort int - description string + name string + wireguardPort *int + expectedPort int + description string }{ { name: "no port specified uses default", diff --git a/client/internal/profilemanager/profilemanager.go b/client/internal/profilemanager/profilemanager.go index fe0afae2b..c87f521cb 100644 --- a/client/internal/profilemanager/profilemanager.go +++ b/client/internal/profilemanager/profilemanager.go @@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error { return nil } + +// GetLoginHint retrieves the email from the active profile to use as login_hint. +func GetLoginHint() string { + pm := NewProfileManager() + activeProf, err := pm.GetActiveProfile() + if err != nil { + log.Debugf("failed to get active profile for login hint: %v", err) + return "" + } + + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + return "" + } + + return profileState.Email +} diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 587e05c74..8d1398a7a 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -18,8 +18,8 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) const ( diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 26cf758d9..2baa0e668 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" - nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" @@ -39,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/client/net" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" "github.com/netbirdio/netbird/version" diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index fa1c89aab..b0d377c21 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 02f09b08a..7b9ae25f7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -137,7 +137,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Severity.Descriptor instead. func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49, 0} + return file_daemon_proto_rawDescGZIP(), []int{51, 0} } type SystemEvent_Category int32 @@ -192,7 +192,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Category.Descriptor instead. func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49, 1} + return file_daemon_proto_rawDescGZIP(), []int{51, 1} } type EmptyRequest struct { @@ -280,9 +280,15 @@ type LoginRequest struct { Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` // hint is used to pre-fill the email/username field during SSO authentication - Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + EnableSSHRoot *bool `protobuf:"varint,34,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP *bool `protobuf:"varint,35,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { @@ -547,6 +553,48 @@ func (x *LoginRequest) GetHint() string { return "" } +func (x *LoginRequest) GetEnableSSHRoot() bool { + if x != nil && x.EnableSSHRoot != nil { + return *x.EnableSSHRoot + } + return false +} + +func (x *LoginRequest) GetEnableSSHSFTP() bool { + if x != nil && x.EnableSSHSFTP != nil { + return *x.EnableSSHSFTP + } + return false +} + +func (x *LoginRequest) GetEnableSSHLocalPortForwarding() bool { + if x != nil && x.EnableSSHLocalPortForwarding != nil { + return *x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool { + if x != nil && x.EnableSSHRemotePortForwarding != nil { + return *x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *LoginRequest) GetDisableSSHAuth() bool { + if x != nil && x.DisableSSHAuth != nil { + return *x.DisableSSHAuth + } + return false +} + +func (x *LoginRequest) GetSshJWTCacheTTL() int32 { + if x != nil && x.SshJWTCacheTTL != nil { + return *x.SshJWTCacheTTL + } + return 0 +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -1057,24 +1105,30 @@ type GetConfigResponse struct { // preSharedKey settings value. PreSharedKey string `protobuf:"bytes,4,opt,name=preSharedKey,proto3" json:"preSharedKey,omitempty"` // adminURL settings value. - AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` - InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` - WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` - Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` - DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` - ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` - RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"` - LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` - BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` - NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"` - DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"` - DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"` - DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"` - BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` + InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` + WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` + Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` + DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` + RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"` + LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` + NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"` + DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"` + DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"` + DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"` + BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"` + EnableSSHRoot bool `protobuf:"varint,21,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GetConfigResponse) Reset() { @@ -1247,6 +1301,48 @@ func (x *GetConfigResponse) GetBlockLanAccess() bool { return false } +func (x *GetConfigResponse) GetEnableSSHRoot() bool { + if x != nil { + return x.EnableSSHRoot + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHSFTP() bool { + if x != nil { + return x.EnableSSHSFTP + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHLocalPortForwarding() bool { + if x != nil { + return x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool { + if x != nil { + return x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *GetConfigResponse) GetDisableSSHAuth() bool { + if x != nil { + return x.DisableSSHAuth + } + return false +} + +func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 { + if x != nil { + return x.SshJWTCacheTTL + } + return 0 +} + // PeerState contains the latest state of a peer type PeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1267,6 +1363,7 @@ type PeerState struct { Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` + SshHostKey []byte `protobuf:"bytes,19,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1420,6 +1517,13 @@ func (x *PeerState) GetRelayAddress() string { return "" } +func (x *PeerState) GetSshHostKey() []byte { + if x != nil { + return x.SshHostKey + } + return nil +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1764,6 +1868,128 @@ func (x *NSGroupState) GetError() string { return "" } +// SSHSessionInfo contains information about an active SSH session +type SSHSessionInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"` + Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"` + JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SSHSessionInfo) Reset() { + *x = SSHSessionInfo{} + mi := &file_daemon_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SSHSessionInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHSessionInfo) ProtoMessage() {} + +func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[19] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead. +func (*SSHSessionInfo) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{19} +} + +func (x *SSHSessionInfo) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *SSHSessionInfo) GetRemoteAddress() string { + if x != nil { + return x.RemoteAddress + } + return "" +} + +func (x *SSHSessionInfo) GetCommand() string { + if x != nil { + return x.Command + } + return "" +} + +func (x *SSHSessionInfo) GetJwtUsername() string { + if x != nil { + return x.JwtUsername + } + return "" +} + +// SSHServerState contains the latest state of the SSH server +type SSHServerState struct { + state protoimpl.MessageState `protogen:"open.v1"` + Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` + Sessions []*SSHSessionInfo `protobuf:"bytes,2,rep,name=sessions,proto3" json:"sessions,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SSHServerState) Reset() { + *x = SSHServerState{} + mi := &file_daemon_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SSHServerState) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHServerState) ProtoMessage() {} + +func (x *SSHServerState) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[20] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead. +func (*SSHServerState) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{20} +} + +func (x *SSHServerState) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +func (x *SSHServerState) GetSessions() []*SSHSessionInfo { + if x != nil { + return x.Sessions + } + return nil +} + // FullStatus contains the full state held by the Status instance type FullStatus struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1776,13 +2002,14 @@ type FullStatus struct { NumberOfForwardingRules int32 `protobuf:"varint,8,opt,name=NumberOfForwardingRules,proto3" json:"NumberOfForwardingRules,omitempty"` Events []*SystemEvent `protobuf:"bytes,7,rep,name=events,proto3" json:"events,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,9,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + SshServerState *SSHServerState `protobuf:"bytes,10,opt,name=sshServerState,proto3" json:"sshServerState,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *FullStatus) Reset() { *x = FullStatus{} - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1794,7 +2021,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1807,7 +2034,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{19} + return file_daemon_proto_rawDescGZIP(), []int{21} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -1873,6 +2100,13 @@ func (x *FullStatus) GetLazyConnectionEnabled() bool { return false } +func (x *FullStatus) GetSshServerState() *SSHServerState { + if x != nil { + return x.SshServerState + } + return nil +} + // Networks type ListNetworksRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1882,7 +2116,7 @@ type ListNetworksRequest struct { func (x *ListNetworksRequest) Reset() { *x = ListNetworksRequest{} - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1894,7 +2128,7 @@ func (x *ListNetworksRequest) String() string { func (*ListNetworksRequest) ProtoMessage() {} func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1907,7 +2141,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. func (*ListNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{20} + return file_daemon_proto_rawDescGZIP(), []int{22} } type ListNetworksResponse struct { @@ -1919,7 +2153,7 @@ type ListNetworksResponse struct { func (x *ListNetworksResponse) Reset() { *x = ListNetworksResponse{} - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1931,7 +2165,7 @@ func (x *ListNetworksResponse) String() string { func (*ListNetworksResponse) ProtoMessage() {} func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1944,7 +2178,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. func (*ListNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{21} + return file_daemon_proto_rawDescGZIP(), []int{23} } func (x *ListNetworksResponse) GetRoutes() []*Network { @@ -1965,7 +2199,7 @@ type SelectNetworksRequest struct { func (x *SelectNetworksRequest) Reset() { *x = SelectNetworksRequest{} - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1977,7 +2211,7 @@ func (x *SelectNetworksRequest) String() string { func (*SelectNetworksRequest) ProtoMessage() {} func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1990,7 +2224,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{22} + return file_daemon_proto_rawDescGZIP(), []int{24} } func (x *SelectNetworksRequest) GetNetworkIDs() []string { @@ -2022,7 +2256,7 @@ type SelectNetworksResponse struct { func (x *SelectNetworksResponse) Reset() { *x = SelectNetworksResponse{} - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2034,7 +2268,7 @@ func (x *SelectNetworksResponse) String() string { func (*SelectNetworksResponse) ProtoMessage() {} func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2047,7 +2281,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{25} } type IPList struct { @@ -2059,7 +2293,7 @@ type IPList struct { func (x *IPList) Reset() { *x = IPList{} - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2071,7 +2305,7 @@ func (x *IPList) String() string { func (*IPList) ProtoMessage() {} func (x *IPList) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2084,7 +2318,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message { // Deprecated: Use IPList.ProtoReflect.Descriptor instead. func (*IPList) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *IPList) GetIps() []string { @@ -2107,7 +2341,7 @@ type Network struct { func (x *Network) Reset() { *x = Network{} - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2119,7 +2353,7 @@ func (x *Network) String() string { func (*Network) ProtoMessage() {} func (x *Network) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2132,7 +2366,7 @@ func (x *Network) ProtoReflect() protoreflect.Message { // Deprecated: Use Network.ProtoReflect.Descriptor instead. func (*Network) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{27} } func (x *Network) GetID() string { @@ -2184,7 +2418,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2196,7 +2430,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2209,7 +2443,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -2266,7 +2500,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2278,7 +2512,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2291,7 +2525,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *ForwardingRule) GetProtocol() string { @@ -2338,7 +2572,7 @@ type ForwardingRulesResponse struct { func (x *ForwardingRulesResponse) Reset() { *x = ForwardingRulesResponse{} - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2350,7 +2584,7 @@ func (x *ForwardingRulesResponse) String() string { func (*ForwardingRulesResponse) ProtoMessage() {} func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2363,7 +2597,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRulesResponse.ProtoReflect.Descriptor instead. func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{30} } func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule { @@ -2387,7 +2621,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2399,7 +2633,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2412,7 +2646,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{31} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -2461,7 +2695,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2473,7 +2707,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2486,7 +2720,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30} + return file_daemon_proto_rawDescGZIP(), []int{32} } func (x *DebugBundleResponse) GetPath() string { @@ -2518,7 +2752,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2530,7 +2764,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2543,7 +2777,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{33} } type GetLogLevelResponse struct { @@ -2555,7 +2789,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2567,7 +2801,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2580,7 +2814,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{34} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -2599,7 +2833,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2611,7 +2845,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2624,7 +2858,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{33} + return file_daemon_proto_rawDescGZIP(), []int{35} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -2642,7 +2876,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2654,7 +2888,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2667,7 +2901,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{34} + return file_daemon_proto_rawDescGZIP(), []int{36} } // State represents a daemon state entry @@ -2680,7 +2914,7 @@ type State struct { func (x *State) Reset() { *x = State{} - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2692,7 +2926,7 @@ func (x *State) String() string { func (*State) ProtoMessage() {} func (x *State) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2705,7 +2939,7 @@ func (x *State) ProtoReflect() protoreflect.Message { // Deprecated: Use State.ProtoReflect.Descriptor instead. func (*State) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{35} + return file_daemon_proto_rawDescGZIP(), []int{37} } func (x *State) GetName() string { @@ -2724,7 +2958,7 @@ type ListStatesRequest struct { func (x *ListStatesRequest) Reset() { *x = ListStatesRequest{} - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2736,7 +2970,7 @@ func (x *ListStatesRequest) String() string { func (*ListStatesRequest) ProtoMessage() {} func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2749,7 +2983,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. func (*ListStatesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{36} + return file_daemon_proto_rawDescGZIP(), []int{38} } // ListStatesResponse contains a list of states @@ -2762,7 +2996,7 @@ type ListStatesResponse struct { func (x *ListStatesResponse) Reset() { *x = ListStatesResponse{} - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2774,7 +3008,7 @@ func (x *ListStatesResponse) String() string { func (*ListStatesResponse) ProtoMessage() {} func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2787,7 +3021,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. func (*ListStatesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{37} + return file_daemon_proto_rawDescGZIP(), []int{39} } func (x *ListStatesResponse) GetStates() []*State { @@ -2808,7 +3042,7 @@ type CleanStateRequest struct { func (x *CleanStateRequest) Reset() { *x = CleanStateRequest{} - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2820,7 +3054,7 @@ func (x *CleanStateRequest) String() string { func (*CleanStateRequest) ProtoMessage() {} func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2833,7 +3067,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. func (*CleanStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{38} + return file_daemon_proto_rawDescGZIP(), []int{40} } func (x *CleanStateRequest) GetStateName() string { @@ -2860,7 +3094,7 @@ type CleanStateResponse struct { func (x *CleanStateResponse) Reset() { *x = CleanStateResponse{} - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2872,7 +3106,7 @@ func (x *CleanStateResponse) String() string { func (*CleanStateResponse) ProtoMessage() {} func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2885,7 +3119,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. func (*CleanStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{39} + return file_daemon_proto_rawDescGZIP(), []int{41} } func (x *CleanStateResponse) GetCleanedStates() int32 { @@ -2906,7 +3140,7 @@ type DeleteStateRequest struct { func (x *DeleteStateRequest) Reset() { *x = DeleteStateRequest{} - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2918,7 +3152,7 @@ func (x *DeleteStateRequest) String() string { func (*DeleteStateRequest) ProtoMessage() {} func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2931,7 +3165,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. func (*DeleteStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{40} + return file_daemon_proto_rawDescGZIP(), []int{42} } func (x *DeleteStateRequest) GetStateName() string { @@ -2958,7 +3192,7 @@ type DeleteStateResponse struct { func (x *DeleteStateResponse) Reset() { *x = DeleteStateResponse{} - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2970,7 +3204,7 @@ func (x *DeleteStateResponse) String() string { func (*DeleteStateResponse) ProtoMessage() {} func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2983,7 +3217,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. func (*DeleteStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{41} + return file_daemon_proto_rawDescGZIP(), []int{43} } func (x *DeleteStateResponse) GetDeletedStates() int32 { @@ -3002,7 +3236,7 @@ type SetSyncResponsePersistenceRequest struct { func (x *SetSyncResponsePersistenceRequest) Reset() { *x = SetSyncResponsePersistenceRequest{} - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3014,7 +3248,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string { func (*SetSyncResponsePersistenceRequest) ProtoMessage() {} func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3027,7 +3261,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{42} + return file_daemon_proto_rawDescGZIP(), []int{44} } func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool { @@ -3045,7 +3279,7 @@ type SetSyncResponsePersistenceResponse struct { func (x *SetSyncResponsePersistenceResponse) Reset() { *x = SetSyncResponsePersistenceResponse{} - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3057,7 +3291,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string { func (*SetSyncResponsePersistenceResponse) ProtoMessage() {} func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3070,7 +3304,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{43} + return file_daemon_proto_rawDescGZIP(), []int{45} } type TCPFlags struct { @@ -3087,7 +3321,7 @@ type TCPFlags struct { func (x *TCPFlags) Reset() { *x = TCPFlags{} - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3099,7 +3333,7 @@ func (x *TCPFlags) String() string { func (*TCPFlags) ProtoMessage() {} func (x *TCPFlags) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3112,7 +3346,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message { // Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. func (*TCPFlags) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{44} + return file_daemon_proto_rawDescGZIP(), []int{46} } func (x *TCPFlags) GetSyn() bool { @@ -3174,7 +3408,7 @@ type TracePacketRequest struct { func (x *TracePacketRequest) Reset() { *x = TracePacketRequest{} - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3186,7 +3420,7 @@ func (x *TracePacketRequest) String() string { func (*TracePacketRequest) ProtoMessage() {} func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3199,7 +3433,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. func (*TracePacketRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{45} + return file_daemon_proto_rawDescGZIP(), []int{47} } func (x *TracePacketRequest) GetSourceIp() string { @@ -3277,7 +3511,7 @@ type TraceStage struct { func (x *TraceStage) Reset() { *x = TraceStage{} - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3289,7 +3523,7 @@ func (x *TraceStage) String() string { func (*TraceStage) ProtoMessage() {} func (x *TraceStage) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3302,7 +3536,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message { // Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. func (*TraceStage) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{46} + return file_daemon_proto_rawDescGZIP(), []int{48} } func (x *TraceStage) GetName() string { @@ -3343,7 +3577,7 @@ type TracePacketResponse struct { func (x *TracePacketResponse) Reset() { *x = TracePacketResponse{} - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3355,7 +3589,7 @@ func (x *TracePacketResponse) String() string { func (*TracePacketResponse) ProtoMessage() {} func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3368,7 +3602,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. func (*TracePacketResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{47} + return file_daemon_proto_rawDescGZIP(), []int{49} } func (x *TracePacketResponse) GetStages() []*TraceStage { @@ -3393,7 +3627,7 @@ type SubscribeRequest struct { func (x *SubscribeRequest) Reset() { *x = SubscribeRequest{} - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3405,7 +3639,7 @@ func (x *SubscribeRequest) String() string { func (*SubscribeRequest) ProtoMessage() {} func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3418,7 +3652,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. func (*SubscribeRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{48} + return file_daemon_proto_rawDescGZIP(), []int{50} } type SystemEvent struct { @@ -3436,7 +3670,7 @@ type SystemEvent struct { func (x *SystemEvent) Reset() { *x = SystemEvent{} - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3448,7 +3682,7 @@ func (x *SystemEvent) String() string { func (*SystemEvent) ProtoMessage() {} func (x *SystemEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3461,7 +3695,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead. func (*SystemEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49} + return file_daemon_proto_rawDescGZIP(), []int{51} } func (x *SystemEvent) GetId() string { @@ -3521,7 +3755,7 @@ type GetEventsRequest struct { func (x *GetEventsRequest) Reset() { *x = GetEventsRequest{} - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3533,7 +3767,7 @@ func (x *GetEventsRequest) String() string { func (*GetEventsRequest) ProtoMessage() {} func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3546,7 +3780,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead. func (*GetEventsRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{50} + return file_daemon_proto_rawDescGZIP(), []int{52} } type GetEventsResponse struct { @@ -3558,7 +3792,7 @@ type GetEventsResponse struct { func (x *GetEventsResponse) Reset() { *x = GetEventsResponse{} - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3570,7 +3804,7 @@ func (x *GetEventsResponse) String() string { func (*GetEventsResponse) ProtoMessage() {} func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3583,7 +3817,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead. func (*GetEventsResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51} + return file_daemon_proto_rawDescGZIP(), []int{53} } func (x *GetEventsResponse) GetEvents() []*SystemEvent { @@ -3603,7 +3837,7 @@ type SwitchProfileRequest struct { func (x *SwitchProfileRequest) Reset() { *x = SwitchProfileRequest{} - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3615,7 +3849,7 @@ func (x *SwitchProfileRequest) String() string { func (*SwitchProfileRequest) ProtoMessage() {} func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3628,7 +3862,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{52} + return file_daemon_proto_rawDescGZIP(), []int{54} } func (x *SwitchProfileRequest) GetProfileName() string { @@ -3653,7 +3887,7 @@ type SwitchProfileResponse struct { func (x *SwitchProfileResponse) Reset() { *x = SwitchProfileResponse{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3665,7 +3899,7 @@ func (x *SwitchProfileResponse) String() string { func (*SwitchProfileResponse) ProtoMessage() {} func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3678,7 +3912,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53} + return file_daemon_proto_rawDescGZIP(), []int{55} } type SetConfigRequest struct { @@ -3711,16 +3945,22 @@ type SetConfigRequest struct { ExtraIFaceBlacklist []string `protobuf:"bytes,24,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` DnsLabels []string `protobuf:"bytes,25,rep,name=dns_labels,json=dnsLabels,proto3" json:"dns_labels,omitempty"` // cleanDNSLabels clean map list of DNS labels. - CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` - DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` - Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` + DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` + Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` + EnableSSHRoot *bool `protobuf:"varint,29,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SetConfigRequest) Reset() { *x = SetConfigRequest{} - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3732,7 +3972,7 @@ func (x *SetConfigRequest) String() string { func (*SetConfigRequest) ProtoMessage() {} func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3745,7 +3985,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. func (*SetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{54} + return file_daemon_proto_rawDescGZIP(), []int{56} } func (x *SetConfigRequest) GetUsername() string { @@ -3944,6 +4184,48 @@ func (x *SetConfigRequest) GetMtu() int64 { return 0 } +func (x *SetConfigRequest) GetEnableSSHRoot() bool { + if x != nil && x.EnableSSHRoot != nil { + return *x.EnableSSHRoot + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHSFTP() bool { + if x != nil && x.EnableSSHSFTP != nil { + return *x.EnableSSHSFTP + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHLocalPortForwarding() bool { + if x != nil && x.EnableSSHLocalPortForwarding != nil { + return *x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHRemotePortForwarding() bool { + if x != nil && x.EnableSSHRemotePortForwarding != nil { + return *x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *SetConfigRequest) GetDisableSSHAuth() bool { + if x != nil && x.DisableSSHAuth != nil { + return *x.DisableSSHAuth + } + return false +} + +func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 { + if x != nil && x.SshJWTCacheTTL != nil { + return *x.SshJWTCacheTTL + } + return 0 +} + type SetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -3952,7 +4234,7 @@ type SetConfigResponse struct { func (x *SetConfigResponse) Reset() { *x = SetConfigResponse{} - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3964,7 +4246,7 @@ func (x *SetConfigResponse) String() string { func (*SetConfigResponse) ProtoMessage() {} func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3977,7 +4259,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. func (*SetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{55} + return file_daemon_proto_rawDescGZIP(), []int{57} } type AddProfileRequest struct { @@ -3990,7 +4272,7 @@ type AddProfileRequest struct { func (x *AddProfileRequest) Reset() { *x = AddProfileRequest{} - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4002,7 +4284,7 @@ func (x *AddProfileRequest) String() string { func (*AddProfileRequest) ProtoMessage() {} func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4015,7 +4297,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. func (*AddProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{56} + return file_daemon_proto_rawDescGZIP(), []int{58} } func (x *AddProfileRequest) GetUsername() string { @@ -4040,7 +4322,7 @@ type AddProfileResponse struct { func (x *AddProfileResponse) Reset() { *x = AddProfileResponse{} - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4052,7 +4334,7 @@ func (x *AddProfileResponse) String() string { func (*AddProfileResponse) ProtoMessage() {} func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4065,7 +4347,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead. func (*AddProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{57} + return file_daemon_proto_rawDescGZIP(), []int{59} } type RemoveProfileRequest struct { @@ -4078,7 +4360,7 @@ type RemoveProfileRequest struct { func (x *RemoveProfileRequest) Reset() { *x = RemoveProfileRequest{} - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4090,7 +4372,7 @@ func (x *RemoveProfileRequest) String() string { func (*RemoveProfileRequest) ProtoMessage() {} func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4103,7 +4385,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{58} + return file_daemon_proto_rawDescGZIP(), []int{60} } func (x *RemoveProfileRequest) GetUsername() string { @@ -4128,7 +4410,7 @@ type RemoveProfileResponse struct { func (x *RemoveProfileResponse) Reset() { *x = RemoveProfileResponse{} - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4140,7 +4422,7 @@ func (x *RemoveProfileResponse) String() string { func (*RemoveProfileResponse) ProtoMessage() {} func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4153,7 +4435,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{59} + return file_daemon_proto_rawDescGZIP(), []int{61} } type ListProfilesRequest struct { @@ -4165,7 +4447,7 @@ type ListProfilesRequest struct { func (x *ListProfilesRequest) Reset() { *x = ListProfilesRequest{} - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4177,7 +4459,7 @@ func (x *ListProfilesRequest) String() string { func (*ListProfilesRequest) ProtoMessage() {} func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4190,7 +4472,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. func (*ListProfilesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{60} + return file_daemon_proto_rawDescGZIP(), []int{62} } func (x *ListProfilesRequest) GetUsername() string { @@ -4209,7 +4491,7 @@ type ListProfilesResponse struct { func (x *ListProfilesResponse) Reset() { *x = ListProfilesResponse{} - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4221,7 +4503,7 @@ func (x *ListProfilesResponse) String() string { func (*ListProfilesResponse) ProtoMessage() {} func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4234,7 +4516,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. func (*ListProfilesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{61} + return file_daemon_proto_rawDescGZIP(), []int{63} } func (x *ListProfilesResponse) GetProfiles() []*Profile { @@ -4254,7 +4536,7 @@ type Profile struct { func (x *Profile) Reset() { *x = Profile{} - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4266,7 +4548,7 @@ func (x *Profile) String() string { func (*Profile) ProtoMessage() {} func (x *Profile) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4279,7 +4561,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message { // Deprecated: Use Profile.ProtoReflect.Descriptor instead. func (*Profile) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{62} + return file_daemon_proto_rawDescGZIP(), []int{64} } func (x *Profile) GetName() string { @@ -4304,7 +4586,7 @@ type GetActiveProfileRequest struct { func (x *GetActiveProfileRequest) Reset() { *x = GetActiveProfileRequest{} - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4316,7 +4598,7 @@ func (x *GetActiveProfileRequest) String() string { func (*GetActiveProfileRequest) ProtoMessage() {} func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4329,7 +4611,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{63} + return file_daemon_proto_rawDescGZIP(), []int{65} } type GetActiveProfileResponse struct { @@ -4342,7 +4624,7 @@ type GetActiveProfileResponse struct { func (x *GetActiveProfileResponse) Reset() { *x = GetActiveProfileResponse{} - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4354,7 +4636,7 @@ func (x *GetActiveProfileResponse) String() string { func (*GetActiveProfileResponse) ProtoMessage() {} func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4367,7 +4649,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{64} + return file_daemon_proto_rawDescGZIP(), []int{66} } func (x *GetActiveProfileResponse) GetProfileName() string { @@ -4394,7 +4676,7 @@ type LogoutRequest struct { func (x *LogoutRequest) Reset() { *x = LogoutRequest{} - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4406,7 +4688,7 @@ func (x *LogoutRequest) String() string { func (*LogoutRequest) ProtoMessage() {} func (x *LogoutRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4419,7 +4701,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. func (*LogoutRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{65} + return file_daemon_proto_rawDescGZIP(), []int{67} } func (x *LogoutRequest) GetProfileName() string { @@ -4444,7 +4726,7 @@ type LogoutResponse struct { func (x *LogoutResponse) Reset() { *x = LogoutResponse{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4456,7 +4738,7 @@ func (x *LogoutResponse) String() string { func (*LogoutResponse) ProtoMessage() {} func (x *LogoutResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4469,7 +4751,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. func (*LogoutResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{66} + return file_daemon_proto_rawDescGZIP(), []int{68} } type GetFeaturesRequest struct { @@ -4480,7 +4762,7 @@ type GetFeaturesRequest struct { func (x *GetFeaturesRequest) Reset() { *x = GetFeaturesRequest{} - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4492,7 +4774,7 @@ func (x *GetFeaturesRequest) String() string { func (*GetFeaturesRequest) ProtoMessage() {} func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4505,7 +4787,7 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead. func (*GetFeaturesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{67} + return file_daemon_proto_rawDescGZIP(), []int{69} } type GetFeaturesResponse struct { @@ -4518,7 +4800,7 @@ type GetFeaturesResponse struct { func (x *GetFeaturesResponse) Reset() { *x = GetFeaturesResponse{} - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4530,7 +4812,7 @@ func (x *GetFeaturesResponse) String() string { func (*GetFeaturesResponse) ProtoMessage() {} func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4543,7 +4825,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead. func (*GetFeaturesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{68} + return file_daemon_proto_rawDescGZIP(), []int{70} } func (x *GetFeaturesResponse) GetDisableProfiles() bool { @@ -4560,6 +4842,390 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer +type GetPeerSSHHostKeyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // peer IP address or FQDN to get SSH host key for + PeerAddress string `protobuf:"bytes,1,opt,name=peerAddress,proto3" json:"peerAddress,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPeerSSHHostKeyRequest) Reset() { + *x = GetPeerSSHHostKeyRequest{} + mi := &file_daemon_proto_msgTypes[71] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPeerSSHHostKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} + +func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[71] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. +func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{71} +} + +func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { + if x != nil { + return x.PeerAddress + } + return "" +} + +// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer +type GetPeerSSHHostKeyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname") + SshHostKey []byte `protobuf:"bytes,1,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` + // peer IP address + PeerIP string `protobuf:"bytes,2,opt,name=peerIP,proto3" json:"peerIP,omitempty"` + // peer FQDN + PeerFQDN string `protobuf:"bytes,3,opt,name=peerFQDN,proto3" json:"peerFQDN,omitempty"` + // indicates if the SSH host key was found + Found bool `protobuf:"varint,4,opt,name=found,proto3" json:"found,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPeerSSHHostKeyResponse) Reset() { + *x = GetPeerSSHHostKeyResponse{} + mi := &file_daemon_proto_msgTypes[72] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPeerSSHHostKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} + +func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[72] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. +func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{72} +} + +func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { + if x != nil { + return x.SshHostKey + } + return nil +} + +func (x *GetPeerSSHHostKeyResponse) GetPeerIP() string { + if x != nil { + return x.PeerIP + } + return "" +} + +func (x *GetPeerSSHHostKeyResponse) GetPeerFQDN() string { + if x != nil { + return x.PeerFQDN + } + return "" +} + +func (x *GetPeerSSHHostKeyResponse) GetFound() bool { + if x != nil { + return x.Found + } + return false +} + +// RequestJWTAuthRequest for initiating JWT authentication flow +type RequestJWTAuthRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // hint for OIDC login_hint parameter (typically email address) + Hint *string `protobuf:"bytes,1,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestJWTAuthRequest) Reset() { + *x = RequestJWTAuthRequest{} + mi := &file_daemon_proto_msgTypes[73] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestJWTAuthRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestJWTAuthRequest) ProtoMessage() {} + +func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[73] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. +func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{73} +} + +func (x *RequestJWTAuthRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + +// RequestJWTAuthResponse contains authentication flow information +type RequestJWTAuthResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // verification URI for user authentication + VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"` + // complete verification URI (with embedded user code) + VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"` + // user code to enter on verification URI + UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"` + // device code for polling + DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"` + // expiration time in seconds + ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"` + // if a cached token is available, it will be returned here + CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"` + // maximum age of JWT tokens in seconds (from management server) + MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestJWTAuthResponse) Reset() { + *x = RequestJWTAuthResponse{} + mi := &file_daemon_proto_msgTypes[74] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestJWTAuthResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestJWTAuthResponse) ProtoMessage() {} + +func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[74] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. +func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{74} +} + +func (x *RequestJWTAuthResponse) GetVerificationURI() string { + if x != nil { + return x.VerificationURI + } + return "" +} + +func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string { + if x != nil { + return x.VerificationURIComplete + } + return "" +} + +func (x *RequestJWTAuthResponse) GetUserCode() string { + if x != nil { + return x.UserCode + } + return "" +} + +func (x *RequestJWTAuthResponse) GetDeviceCode() string { + if x != nil { + return x.DeviceCode + } + return "" +} + +func (x *RequestJWTAuthResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + +func (x *RequestJWTAuthResponse) GetCachedToken() string { + if x != nil { + return x.CachedToken + } + return "" +} + +func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 { + if x != nil { + return x.MaxTokenAge + } + return 0 +} + +// WaitJWTTokenRequest for waiting for authentication completion +type WaitJWTTokenRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // device code from RequestJWTAuthResponse + DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"` + // user code for verification + UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WaitJWTTokenRequest) Reset() { + *x = WaitJWTTokenRequest{} + mi := &file_daemon_proto_msgTypes[75] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WaitJWTTokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WaitJWTTokenRequest) ProtoMessage() {} + +func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[75] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. +func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{75} +} + +func (x *WaitJWTTokenRequest) GetDeviceCode() string { + if x != nil { + return x.DeviceCode + } + return "" +} + +func (x *WaitJWTTokenRequest) GetUserCode() string { + if x != nil { + return x.UserCode + } + return "" +} + +// WaitJWTTokenResponse contains the JWT token after authentication +type WaitJWTTokenResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // JWT token (access token or ID token) + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + // token type (e.g., "Bearer") + TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"` + // expiration time in seconds + ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WaitJWTTokenResponse) Reset() { + *x = WaitJWTTokenResponse{} + mi := &file_daemon_proto_msgTypes[76] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WaitJWTTokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WaitJWTTokenResponse) ProtoMessage() {} + +func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[76] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. +func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{76} +} + +func (x *WaitJWTTokenResponse) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +func (x *WaitJWTTokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *WaitJWTTokenResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -4570,7 +5236,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4582,7 +5248,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4595,7 +5261,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26, 0} + return file_daemon_proto_rawDescGZIP(), []int{28, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -4617,7 +5283,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xe5\x0e\n" + + "\fEmptyRequest\"\xb6\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4655,7 +5321,13 @@ const file_daemon_proto_rawDesc = "" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" + - "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" + + "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01\x12)\n" + + "\renableSSHRoot\x18\" \x01(\bH\x15R\renableSSHRoot\x88\x01\x01\x12)\n" + + "\renableSSHSFTP\x18# \x01(\bH\x16R\renableSSHSFTP\x88\x01\x01\x12G\n" + + "\x1cenableSSHLocalPortForwarding\x18$ \x01(\bH\x17R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + + "\x1denableSSHRemotePortForwarding\x18% \x01(\bH\x18R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + + "\x0edisableSSHAuth\x18& \x01(\bH\x19R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + + "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4676,7 +5348,13 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_usernameB\x06\n" + "\x04_mtuB\a\n" + - "\x05_hint\"\xb5\x01\n" + + "\x05_hintB\x10\n" + + "\x0e_enableSSHRootB\x10\n" + + "\x0e_enableSSHSFTPB\x1f\n" + + "\x1d_enableSSHLocalPortForwardingB \n" + + "\x1e_enableSSHRemotePortForwardingB\x11\n" + + "\x0f_disableSSHAuthB\x11\n" + + "\x0f_sshJWTCacheTTL\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -4709,7 +5387,7 @@ const file_daemon_proto_rawDesc = "" + "\fDownResponse\"P\n" + "\x10GetConfigRequest\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername\"\xb5\x06\n" + + "\busername\x18\x02 \x01(\tR\busername\"\xdb\b\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -4734,7 +5412,13 @@ const file_daemon_proto_rawDesc = "" + "disableDns\x122\n" + "\x15disable_client_routes\x18\x12 \x01(\bR\x13disableClientRoutes\x122\n" + "\x15disable_server_routes\x18\x13 \x01(\bR\x13disableServerRoutes\x12(\n" + - "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\"\xde\x05\n" + + "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\x12$\n" + + "\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" + + "\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" + + "\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" + + "\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" + + "\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" + + "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\"\xfe\x05\n" + "\tPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" + @@ -4755,7 +5439,10 @@ const file_daemon_proto_rawDesc = "" + "\x10rosenpassEnabled\x18\x0f \x01(\bR\x10rosenpassEnabled\x12\x1a\n" + "\bnetworks\x18\x10 \x03(\tR\bnetworks\x123\n" + "\alatency\x18\x11 \x01(\v2\x19.google.protobuf.DurationR\alatency\x12\"\n" + - "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\"\xf0\x01\n" + + "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\x12\x1e\n" + + "\n" + + "sshHostKey\x18\x13 \x01(\fR\n" + + "sshHostKey\"\xf0\x01\n" + "\x0eLocalPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" + @@ -4781,7 +5468,15 @@ const file_daemon_proto_rawDesc = "" + "\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" + "\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" + "\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" + - "\x05error\x18\x04 \x01(\tR\x05error\"\xef\x03\n" + + "\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" + + "\x0eSSHSessionInfo\x12\x1a\n" + + "\busername\x18\x01 \x01(\tR\busername\x12$\n" + + "\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" + + "\acommand\x18\x03 \x01(\tR\acommand\x12 \n" + + "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" + + "\x0eSSHServerState\x12\x18\n" + + "\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" + + "\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" + "\n" + "FullStatus\x12A\n" + "\x0fmanagementState\x18\x01 \x01(\v2\x17.daemon.ManagementStateR\x0fmanagementState\x125\n" + @@ -4793,7 +5488,9 @@ const file_daemon_proto_rawDesc = "" + "dnsServers\x128\n" + "\x17NumberOfForwardingRules\x18\b \x01(\x05R\x17NumberOfForwardingRules\x12+\n" + "\x06events\x18\a \x03(\v2\x13.daemon.SystemEventR\x06events\x124\n" + - "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\"\x15\n" + + "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\x12>\n" + + "\x0esshServerState\x18\n" + + " \x01(\v2\x16.daemon.SSHServerStateR\x0esshServerState\"\x15\n" + "\x13ListNetworksRequest\"?\n" + "\x14ListNetworksResponse\x12'\n" + "\x06routes\x18\x01 \x03(\v2\x0f.daemon.NetworkR\x06routes\"a\n" + @@ -4934,7 +5631,7 @@ const file_daemon_proto_rawDesc = "" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + "\t_username\"\x17\n" + - "\x15SwitchProfileResponse\"\x8e\r\n" + + "\x15SwitchProfileResponse\"\xdf\x10\n" + "\x10SetConfigRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + @@ -4967,7 +5664,13 @@ const file_daemon_proto_rawDesc = "" + "dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" + "\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" + "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01\x12)\n" + + "\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" + + "\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12G\n" + + "\x1cenableSSHLocalPortForwarding\x18\x1f \x01(\bH\x14R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + + "\x1denableSSHRemotePortForwarding\x18 \x01(\bH\x15R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + + "\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + + "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4985,7 +5688,13 @@ const file_daemon_proto_rawDesc = "" + "\x16_lazyConnectionEnabledB\x10\n" + "\x0e_block_inboundB\x13\n" + "\x11_dnsRouteIntervalB\x06\n" + - "\x04_mtu\"\x13\n" + + "\x04_mtuB\x10\n" + + "\x0e_enableSSHRootB\x10\n" + + "\x0e_enableSSHSFTPB\x1f\n" + + "\x1d_enableSSHLocalPortForwardingB \n" + + "\x1e_enableSSHRemotePortForwardingB\x11\n" + + "\x0f_disableSSHAuthB\x11\n" + + "\x0f_sshJWTCacheTTL\"\x13\n" + "\x11SetConfigResponse\"Q\n" + "\x11AddProfileRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + @@ -5015,7 +5724,38 @@ const file_daemon_proto_rawDesc = "" + "\x12GetFeaturesRequest\"x\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings*b\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" + + "\x18GetPeerSSHHostKeyRequest\x12 \n" + + "\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" + + "\x19GetPeerSSHHostKeyResponse\x12\x1e\n" + + "\n" + + "sshHostKey\x18\x01 \x01(\fR\n" + + "sshHostKey\x12\x16\n" + + "\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" + + "\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" + + "\x05found\x18\x04 \x01(\bR\x05found\"9\n" + + "\x15RequestJWTAuthRequest\x12\x17\n" + + "\x04hint\x18\x01 \x01(\tH\x00R\x04hint\x88\x01\x01B\a\n" + + "\x05_hint\"\x9a\x02\n" + + "\x16RequestJWTAuthResponse\x12(\n" + + "\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" + + "\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" + + "\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" + + "\n" + + "deviceCode\x18\x04 \x01(\tR\n" + + "deviceCode\x12\x1c\n" + + "\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" + + "\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" + + "\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" + + "\x13WaitJWTTokenRequest\x12\x1e\n" + + "\n" + + "deviceCode\x18\x01 \x01(\tR\n" + + "deviceCode\x12\x1a\n" + + "\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" + + "\x14WaitJWTTokenResponse\x12\x14\n" + + "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" + + "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" + + "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -5024,7 +5764,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\x8f\x10\n" + + "\x05TRACE\x10\a2\x8b\x12\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -5056,7 +5796,10 @@ const file_daemon_proto_rawDesc = "" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" + - "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00B\bZ\x06/protob\x06proto3" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -5071,7 +5814,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 72) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 80) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity @@ -5095,155 +5838,171 @@ var file_daemon_proto_goTypes = []any{ (*ManagementState)(nil), // 19: daemon.ManagementState (*RelayState)(nil), // 20: daemon.RelayState (*NSGroupState)(nil), // 21: daemon.NSGroupState - (*FullStatus)(nil), // 22: daemon.FullStatus - (*ListNetworksRequest)(nil), // 23: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 24: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 25: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 26: daemon.SelectNetworksResponse - (*IPList)(nil), // 27: daemon.IPList - (*Network)(nil), // 28: daemon.Network - (*PortInfo)(nil), // 29: daemon.PortInfo - (*ForwardingRule)(nil), // 30: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 31: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 32: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 33: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 34: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 35: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 36: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 37: daemon.SetLogLevelResponse - (*State)(nil), // 38: daemon.State - (*ListStatesRequest)(nil), // 39: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 40: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 41: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 42: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 43: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 44: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 45: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 46: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 47: daemon.TCPFlags - (*TracePacketRequest)(nil), // 48: daemon.TracePacketRequest - (*TraceStage)(nil), // 49: daemon.TraceStage - (*TracePacketResponse)(nil), // 50: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 51: daemon.SubscribeRequest - (*SystemEvent)(nil), // 52: daemon.SystemEvent - (*GetEventsRequest)(nil), // 53: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 54: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 55: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 56: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 57: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 58: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 59: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 60: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 61: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 62: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 63: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 64: daemon.ListProfilesResponse - (*Profile)(nil), // 65: daemon.Profile - (*GetActiveProfileRequest)(nil), // 66: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 68: daemon.LogoutRequest - (*LogoutResponse)(nil), // 69: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 70: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse - nil, // 72: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 73: daemon.PortInfo.Range - nil, // 74: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 75: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 76: google.protobuf.Timestamp + (*SSHSessionInfo)(nil), // 22: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 23: daemon.SSHServerState + (*FullStatus)(nil), // 24: daemon.FullStatus + (*ListNetworksRequest)(nil), // 25: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 26: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 27: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 28: daemon.SelectNetworksResponse + (*IPList)(nil), // 29: daemon.IPList + (*Network)(nil), // 30: daemon.Network + (*PortInfo)(nil), // 31: daemon.PortInfo + (*ForwardingRule)(nil), // 32: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 33: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 34: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 35: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 36: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 37: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 38: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 39: daemon.SetLogLevelResponse + (*State)(nil), // 40: daemon.State + (*ListStatesRequest)(nil), // 41: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 42: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 43: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 44: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 45: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 46: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 47: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 48: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 49: daemon.TCPFlags + (*TracePacketRequest)(nil), // 50: daemon.TracePacketRequest + (*TraceStage)(nil), // 51: daemon.TraceStage + (*TracePacketResponse)(nil), // 52: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 53: daemon.SubscribeRequest + (*SystemEvent)(nil), // 54: daemon.SystemEvent + (*GetEventsRequest)(nil), // 55: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 56: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 57: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 58: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 59: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 60: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 61: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 62: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 63: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 64: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 65: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 66: daemon.ListProfilesResponse + (*Profile)(nil), // 67: daemon.Profile + (*GetActiveProfileRequest)(nil), // 68: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 69: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), // 70: daemon.LogoutRequest + (*LogoutResponse)(nil), // 71: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 72: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 73: daemon.GetFeaturesResponse + (*GetPeerSSHHostKeyRequest)(nil), // 74: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 75: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 76: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 77: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 78: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 79: daemon.WaitJWTTokenResponse + nil, // 80: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 81: daemon.PortInfo.Range + nil, // 82: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 83: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 84: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 75, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 76, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 76, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 75, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 16, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState - 20, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState - 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 72, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 73, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 18: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 19: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 38, // 20: daemon.ListStatesResponse.states:type_name -> daemon.State - 47, // 21: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 76, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 74, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 75, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 6, // 32: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 8, // 33: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 10, // 34: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 12, // 35: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 14, // 36: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 23, // 37: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 25, // 38: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 25, // 39: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 3, // 40: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 32, // 41: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 34, // 42: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 36, // 43: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 39, // 44: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 41, // 45: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 43, // 46: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 45, // 47: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 48, // 48: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 51, // 49: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 53, // 50: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 55, // 51: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 57, // 52: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 59, // 53: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 61, // 54: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 5, // 59: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 7, // 60: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 9, // 61: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 11, // 62: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 13, // 63: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 15, // 64: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 24, // 65: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 26, // 66: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 26, // 67: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 68: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 33, // 69: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 35, // 70: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 37, // 71: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 40, // 72: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 42, // 73: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 44, // 74: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 46, // 75: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 50, // 76: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 52, // 77: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 54, // 78: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 56, // 79: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 58, // 80: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 60, // 81: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 62, // 82: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 64, // 83: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 67, // 84: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 69, // 85: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 71, // 86: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 59, // [59:87] is the sub-list for method output_type - 31, // [31:59] is the sub-list for method input_type - 31, // [31:31] is the sub-list for extension type_name - 31, // [31:31] is the sub-list for extension extendee - 0, // [0:31] is the sub-list for field type_name + 83, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 24, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 84, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 84, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 83, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 22, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 19, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 18, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 17, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 16, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState + 20, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState + 21, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 54, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 23, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 30, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 80, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 81, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 31, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 31, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 32, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 40, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State + 49, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 51, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 1, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 2, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 84, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 82, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 54, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 83, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 67, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 29, // 32: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 4, // 33: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 6, // 34: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 8, // 35: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 10, // 36: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 12, // 37: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 14, // 38: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 25, // 39: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 27, // 40: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 27, // 41: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 3, // 42: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 34, // 43: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 36, // 44: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 38, // 45: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 41, // 46: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 43, // 47: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 45, // 48: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 47, // 49: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 50, // 50: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 53, // 51: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 55, // 52: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 57, // 53: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 59, // 54: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 61, // 55: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 63, // 56: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 65, // 57: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 68, // 58: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 70, // 59: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 72, // 60: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 74, // 61: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 76, // 62: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 78, // 63: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 5, // 64: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 7, // 65: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 9, // 66: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 11, // 67: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 13, // 68: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 15, // 69: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 26, // 70: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 28, // 71: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 28, // 72: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 33, // 73: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 35, // 74: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 37, // 75: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 39, // 76: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 42, // 77: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 44, // 78: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 46, // 79: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 48, // 80: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 52, // 81: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 54, // 82: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 56, // 83: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 58, // 84: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 60, // 85: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 62, // 86: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 64, // 87: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 66, // 88: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 69, // 89: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 71, // 90: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 73, // 91: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 75, // 92: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 77, // 93: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 79, // 94: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 64, // [64:95] is the sub-list for method output_type + 33, // [33:64] is the sub-list for method input_type + 33, // [33:33] is the sub-list for extension type_name + 33, // [33:33] is the sub-list for extension extendee + 0, // [0:33] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -5254,22 +6013,23 @@ func file_daemon_proto_init() { file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} file_daemon_proto_msgTypes[7].OneofWrappers = []any{} - file_daemon_proto_msgTypes[26].OneofWrappers = []any{ + file_daemon_proto_msgTypes[28].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } - file_daemon_proto_msgTypes[45].OneofWrappers = []any{} - file_daemon_proto_msgTypes[46].OneofWrappers = []any{} - file_daemon_proto_msgTypes[52].OneofWrappers = []any{} + file_daemon_proto_msgTypes[47].OneofWrappers = []any{} + file_daemon_proto_msgTypes[48].OneofWrappers = []any{} file_daemon_proto_msgTypes[54].OneofWrappers = []any{} - file_daemon_proto_msgTypes[65].OneofWrappers = []any{} + file_daemon_proto_msgTypes[56].OneofWrappers = []any{} + file_daemon_proto_msgTypes[67].OneofWrappers = []any{} + file_daemon_proto_msgTypes[73].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 3, - NumMessages: 72, + NumMessages: 80, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 8d1080051..bf8553706 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -84,6 +84,15 @@ service DaemonService { rpc Logout(LogoutRequest) returns (LogoutResponse) {} rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {} + + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {} + + // RequestJWTAuth initiates JWT authentication flow for SSH + rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {} + + // WaitJWTToken waits for JWT authentication completion + rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} } @@ -161,6 +170,13 @@ message LoginRequest { // hint is used to pre-fill the email/username field during SSO authentication optional string hint = 33; + + optional bool enableSSHRoot = 34; + optional bool enableSSHSFTP = 35; + optional bool enableSSHLocalPortForwarding = 36; + optional bool enableSSHRemotePortForwarding = 37; + optional bool disableSSHAuth = 38; + optional int32 sshJWTCacheTTL = 39; } message LoginResponse { @@ -188,9 +204,9 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; - bool shouldRunProbes = 2; + bool shouldRunProbes = 2; // the UI do not using this yet, but CLIs could use it to wait until the status is ready - optional bool waitForReady = 3; + optional bool waitForReady = 3; } message StatusResponse{ @@ -255,6 +271,18 @@ message GetConfigResponse { bool disable_server_routes = 19; bool block_lan_access = 20; + + bool enableSSHRoot = 21; + + bool enableSSHSFTP = 24; + + bool enableSSHLocalPortForwarding = 22; + + bool enableSSHRemotePortForwarding = 23; + + bool disableSSHAuth = 25; + + int32 sshJWTCacheTTL = 26; } // PeerState contains the latest state of a peer @@ -276,6 +304,7 @@ message PeerState { repeated string networks = 16; google.protobuf.Duration latency = 17; string relayAddress = 18; + bytes sshHostKey = 19; } // LocalPeerState contains the latest state of the local peer @@ -317,6 +346,20 @@ message NSGroupState { string error = 4; } +// SSHSessionInfo contains information about an active SSH session +message SSHSessionInfo { + string username = 1; + string remoteAddress = 2; + string command = 3; + string jwtUsername = 4; +} + +// SSHServerState contains the latest state of the SSH server +message SSHServerState { + bool enabled = 1; + repeated SSHSessionInfo sessions = 2; +} + // FullStatus contains the full state held by the Status instance message FullStatus { ManagementState managementState = 1; @@ -330,6 +373,7 @@ message FullStatus { repeated SystemEvent events = 7; bool lazyConnectionEnabled = 9; + SSHServerState sshServerState = 10; } // Networks @@ -543,56 +587,63 @@ message SwitchProfileRequest { message SwitchProfileResponse {} message SetConfigRequest { - string username = 1; - string profileName = 2; - // managementUrl to authenticate. - string managementUrl = 3; + string username = 1; + string profileName = 2; + // managementUrl to authenticate. + string managementUrl = 3; - // adminUrl to manage keys. - string adminURL = 4; + // adminUrl to manage keys. + string adminURL = 4; - optional bool rosenpassEnabled = 5; + optional bool rosenpassEnabled = 5; - optional string interfaceName = 6; + optional string interfaceName = 6; - optional int64 wireguardPort = 7; + optional int64 wireguardPort = 7; - optional string optionalPreSharedKey = 8; + optional string optionalPreSharedKey = 8; - optional bool disableAutoConnect = 9; + optional bool disableAutoConnect = 9; - optional bool serverSSHAllowed = 10; + optional bool serverSSHAllowed = 10; - optional bool rosenpassPermissive = 11; + optional bool rosenpassPermissive = 11; - optional bool networkMonitor = 12; + optional bool networkMonitor = 12; - optional bool disable_client_routes = 13; - optional bool disable_server_routes = 14; - optional bool disable_dns = 15; - optional bool disable_firewall = 16; - optional bool block_lan_access = 17; + optional bool disable_client_routes = 13; + optional bool disable_server_routes = 14; + optional bool disable_dns = 15; + optional bool disable_firewall = 16; + optional bool block_lan_access = 17; - optional bool disable_notifications = 18; + optional bool disable_notifications = 18; - optional bool lazyConnectionEnabled = 19; + optional bool lazyConnectionEnabled = 19; - optional bool block_inbound = 20; + optional bool block_inbound = 20; - repeated string natExternalIPs = 21; - bool cleanNATExternalIPs = 22; + repeated string natExternalIPs = 21; + bool cleanNATExternalIPs = 22; - bytes customDNSAddress = 23; + bytes customDNSAddress = 23; - repeated string extraIFaceBlacklist = 24; + repeated string extraIFaceBlacklist = 24; - repeated string dns_labels = 25; - // cleanDNSLabels clean map list of DNS labels. - bool cleanDNSLabels = 26; + repeated string dns_labels = 25; + // cleanDNSLabels clean map list of DNS labels. + bool cleanDNSLabels = 26; - optional google.protobuf.Duration dnsRouteInterval = 27; + optional google.protobuf.Duration dnsRouteInterval = 27; - optional int64 mtu = 28; + optional int64 mtu = 28; + + optional bool enableSSHRoot = 29; + optional bool enableSSHSFTP = 30; + optional bool enableSSHLocalPortForwarding = 31; + optional bool enableSSHRemotePortForwarding = 32; + optional bool disableSSHAuth = 33; + optional int32 sshJWTCacheTTL = 34; } message SetConfigResponse{} @@ -644,3 +695,63 @@ message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; } + +// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer +message GetPeerSSHHostKeyRequest { + // peer IP address or FQDN to get SSH host key for + string peerAddress = 1; +} + +// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer +message GetPeerSSHHostKeyResponse { + // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname") + bytes sshHostKey = 1; + // peer IP address + string peerIP = 2; + // peer FQDN + string peerFQDN = 3; + // indicates if the SSH host key was found + bool found = 4; +} + +// RequestJWTAuthRequest for initiating JWT authentication flow +message RequestJWTAuthRequest { + // hint for OIDC login_hint parameter (typically email address) + optional string hint = 1; +} + +// RequestJWTAuthResponse contains authentication flow information +message RequestJWTAuthResponse { + // verification URI for user authentication + string verificationURI = 1; + // complete verification URI (with embedded user code) + string verificationURIComplete = 2; + // user code to enter on verification URI + string userCode = 3; + // device code for polling + string deviceCode = 4; + // expiration time in seconds + int64 expiresIn = 5; + // if a cached token is available, it will be returned here + string cachedToken = 6; + // maximum age of JWT tokens in seconds (from management server) + int64 maxTokenAge = 7; +} + +// WaitJWTTokenRequest for waiting for authentication completion +message WaitJWTTokenRequest { + // device code from RequestJWTAuthResponse + string deviceCode = 1; + // user code for verification + string userCode = 2; +} + +// WaitJWTTokenResponse contains the JWT token after authentication +message WaitJWTTokenResponse { + // JWT token (access token or ID token) + string token = 1; + // token type (e.g., "Bearer") + string tokenType = 2; + // expiration time in seconds + int64 expiresIn = 3; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index bf7c9c7b3..b2bf716b2 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -64,6 +64,12 @@ type DaemonServiceClient interface { // Logout disconnects from the network and deletes the peer from the management server Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) + // RequestJWTAuth initiates JWT authentication flow for SSH + RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) + // WaitJWTToken waits for JWT authentication completion + WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) } type daemonServiceClient struct { @@ -349,6 +355,33 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe return out, nil } +func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { + out := new(GetPeerSSHHostKeyResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) { + out := new(RequestJWTAuthResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) { + out := new(WaitJWTTokenResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -399,6 +432,12 @@ type DaemonServiceServer interface { // Logout disconnects from the network and deletes the peer from the management server Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) + // RequestJWTAuth initiates JWT authentication flow for SSH + RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) + // WaitJWTToken waits for JWT authentication completion + WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -490,6 +529,15 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") } +func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") +} +func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented") +} +func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1010,6 +1058,60 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetPeerSSHHostKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RequestJWTAuthRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/RequestJWTAuth", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(WaitJWTTokenRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).WaitJWTToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/WaitJWTToken", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1125,6 +1227,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeatures", Handler: _DaemonService_GetFeatures_Handler, }, + { + MethodName: "GetPeerSSHHostKey", + Handler: _DaemonService_GetPeerSSHHostKey_Handler, + }, + { + MethodName: "RequestJWTAuth", + Handler: _DaemonService_RequestJWTAuth_Handler, + }, + { + MethodName: "WaitJWTToken", + Handler: _DaemonService_WaitJWTToken_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/server/jwt_cache.go b/client/server/jwt_cache.go new file mode 100644 index 000000000..21e170517 --- /dev/null +++ b/client/server/jwt_cache.go @@ -0,0 +1,79 @@ +package server + +import ( + "sync" + "time" + + "github.com/awnumar/memguard" + log "github.com/sirupsen/logrus" +) + +type jwtCache struct { + mu sync.RWMutex + enclave *memguard.Enclave + expiresAt time.Time + timer *time.Timer + maxTokenSize int +} + +func newJWTCache() *jwtCache { + return &jwtCache{ + maxTokenSize: 8192, + } +} + +func (c *jwtCache) store(token string, maxAge time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cleanup() + + if c.timer != nil { + c.timer.Stop() + } + + tokenBytes := []byte(token) + c.enclave = memguard.NewEnclave(tokenBytes) + + c.expiresAt = time.Now().Add(maxAge) + + var timer *time.Timer + timer = time.AfterFunc(maxAge, func() { + c.mu.Lock() + defer c.mu.Unlock() + if c.timer != timer { + return + } + c.cleanup() + c.timer = nil + log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge) + }) + c.timer = timer +} + +func (c *jwtCache) get() (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.enclave == nil || time.Now().After(c.expiresAt) { + return "", false + } + + buffer, err := c.enclave.Open() + if err != nil { + log.Debugf("Failed to open JWT token enclave: %v", err) + return "", false + } + defer buffer.Destroy() + + token := string(buffer.Bytes()) + return token, true +} + +// cleanup destroys the secure enclave, must be called with lock held +func (c *jwtCache) cleanup() { + if c.enclave != nil { + c.enclave = nil + } + c.expiresAt = time.Time{} +} diff --git a/client/server/network.go b/client/server/network.go index 18b16795d..bb1cce56c 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -11,8 +11,8 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) type selectRoute struct { diff --git a/client/server/server.go b/client/server/server.go index 6699cdadc..a930e8a02 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -46,6 +46,9 @@ const ( defaultMaxRetryTime = 14 * 24 * time.Hour defaultRetryMultiplier = 1.7 + // JWT token cache TTL for the client daemon (disabled by default) + defaultJWTCacheTTL = 0 + errRestoreResidualState = "failed to restore residual state: %v" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" @@ -81,6 +84,8 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool + + jwtCache *jwtCache } type oauthAuthFlow struct { @@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + jwtCache: newJWTCache(), } } @@ -373,6 +379,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.DisableNotifications = msg.DisableNotifications config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound + config.EnableSSHRoot = msg.EnableSSHRoot + config.EnableSSHSFTP = msg.EnableSSHSFTP + config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding + config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding + if msg.DisableSSHAuth != nil { + config.DisableSSHAuth = msg.DisableSSHAuth + } + if msg.SshJWTCacheTTL != nil { + ttl := int(*msg.SshJWTCacheTTL) + config.SSHJWTCacheTTL = &ttl + } if msg.Mtu != nil { mtu := uint16(*msg.Mtu) @@ -493,7 +510,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro return nil, err } - if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) { + if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { log.Debugf("using previous oauth flow info") return &proto.LoginResponse{ @@ -510,7 +527,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } } - authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) + authInfo, err := oAuthFlow.RequestAuthInfo(ctx) if err != nil { log.Errorf("getting a request OAuth flow failed: %v", err) return nil, err @@ -1065,12 +1082,235 @@ func (s *Server) Status( fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() + + pbFullStatus.SshServerState = s.getSSHServerState() + statusResponse.FullStatus = pbFullStatus } return &statusResponse, nil } +// getSSHServerState retrieves the current SSH server state including enabled status and active sessions +func (s *Server) getSSHServerState() *proto.SSHServerState { + s.mutex.Lock() + connectClient := s.connectClient + s.mutex.Unlock() + + if connectClient == nil { + return nil + } + + engine := connectClient.Engine() + if engine == nil { + return nil + } + + enabled, sessions := engine.GetSSHServerStatus() + sshServerState := &proto.SSHServerState{ + Enabled: enabled, + } + + for _, session := range sessions { + sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{ + Username: session.Username, + RemoteAddress: session.RemoteAddress, + Command: session.Command, + JwtUsername: session.JWTUsername, + }) + } + + return sshServerState +} + +// GetPeerSSHHostKey retrieves SSH host key for a specific peer +func (s *Server) GetPeerSSHHostKey( + ctx context.Context, + req *proto.GetPeerSSHHostKeyRequest, +) (*proto.GetPeerSSHHostKeyResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + connectClient := s.connectClient + statusRecorder := s.statusRecorder + s.mutex.Unlock() + + if connectClient == nil { + return nil, errors.New("client not initialized") + } + + engine := connectClient.Engine() + if engine == nil { + return nil, errors.New("engine not started") + } + + peerAddress := req.GetPeerAddress() + hostKey, found := engine.GetPeerSSHKey(peerAddress) + + response := &proto.GetPeerSSHHostKeyResponse{ + Found: found, + } + + if !found { + return response, nil + } + + response.SshHostKey = hostKey + + if statusRecorder == nil { + return response, nil + } + + fullStatus := statusRecorder.GetFullStatus() + for _, peerState := range fullStatus.Peers { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + response.PeerIP = peerState.IP + response.PeerFQDN = peerState.FQDN + break + } + } + + return response, nil +} + +// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled) +func (s *Server) getJWTCacheTTL() time.Duration { + s.mutex.Lock() + config := s.config + s.mutex.Unlock() + + if config == nil || config.SSHJWTCacheTTL == nil { + return defaultJWTCacheTTL + } + + seconds := *config.SSHJWTCacheTTL + if seconds == 0 { + log.Debug("SSH JWT cache disabled (configured to 0)") + return 0 + } + + ttl := time.Duration(seconds) * time.Second + log.Debugf("SSH JWT cache TTL set to %v from config", ttl) + return ttl +} + +// RequestJWTAuth initiates JWT authentication flow for SSH +func (s *Server) RequestJWTAuth( + ctx context.Context, + msg *proto.RequestJWTAuthRequest, +) (*proto.RequestJWTAuthResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + config := s.config + s.mutex.Unlock() + + if config == nil { + return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured") + } + + jwtCacheTTL := s.getJWTCacheTTL() + if jwtCacheTTL > 0 { + if cachedToken, found := s.jwtCache.get(); found { + log.Debugf("JWT token found in cache, returning cached token for SSH authentication") + + return &proto.RequestJWTAuthResponse{ + CachedToken: cachedToken, + MaxTokenAge: int64(jwtCacheTTL.Seconds()), + }, nil + } + } + + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + + if hint == "" { + hint = profilemanager.GetLoginHint() + } + + isDesktop := isUnixRunningDesktop() + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err) + } + + authInfo, err := oAuthFlow.RequestAuthInfo(ctx) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err) + } + + s.mutex.Lock() + s.oauthAuthFlow.flow = oAuthFlow + s.oauthAuthFlow.info = authInfo + s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second) + s.mutex.Unlock() + + return &proto.RequestJWTAuthResponse{ + VerificationURI: authInfo.VerificationURI, + VerificationURIComplete: authInfo.VerificationURIComplete, + UserCode: authInfo.UserCode, + DeviceCode: authInfo.DeviceCode, + ExpiresIn: int64(authInfo.ExpiresIn), + MaxTokenAge: int64(jwtCacheTTL.Seconds()), + }, nil +} + +// WaitJWTToken waits for JWT authentication completion +func (s *Server) WaitJWTToken( + ctx context.Context, + req *proto.WaitJWTTokenRequest, +) (*proto.WaitJWTTokenResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + oAuthFlow := s.oauthAuthFlow.flow + authInfo := s.oauthAuthFlow.info + s.mutex.Unlock() + + if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode { + return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow") + } + + tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err) + } + + token := tokenInfo.GetTokenToUse() + + jwtCacheTTL := s.getJWTCacheTTL() + if jwtCacheTTL > 0 { + s.jwtCache.store(token, jwtCacheTTL) + log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL) + } else { + log.Debug("JWT caching disabled, not storing token") + } + + s.mutex.Lock() + s.oauthAuthFlow = oauthAuthFlow{} + s.mutex.Unlock() + return &proto.WaitJWTTokenResponse{ + Token: tokenInfo.GetTokenToUse(), + TokenType: tokenInfo.TokenType, + ExpiresIn: int64(tokenInfo.ExpiresIn), + }, nil +} + +func isUnixRunningDesktop() bool { + if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { + return false + } + return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" +} + func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return @@ -1136,25 +1376,61 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p disableServerRoutes := cfg.DisableServerRoutes blockLANAccess := cfg.BlockLANAccess + enableSSHRoot := false + if cfg.EnableSSHRoot != nil { + enableSSHRoot = *cfg.EnableSSHRoot + } + + enableSSHSFTP := false + if cfg.EnableSSHSFTP != nil { + enableSSHSFTP = *cfg.EnableSSHSFTP + } + + enableSSHLocalPortForwarding := false + if cfg.EnableSSHLocalPortForwarding != nil { + enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding + } + + enableSSHRemotePortForwarding := false + if cfg.EnableSSHRemotePortForwarding != nil { + enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding + } + + disableSSHAuth := false + if cfg.DisableSSHAuth != nil { + disableSSHAuth = *cfg.DisableSSHAuth + } + + sshJWTCacheTTL := int32(0) + if cfg.SSHJWTCacheTTL != nil { + sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL) + } + return &proto.GetConfigResponse{ - ManagementUrl: managementURL.String(), - PreSharedKey: preSharedKey, - AdminURL: adminURL.String(), - InterfaceName: cfg.WgIface, - WireguardPort: int64(cfg.WgPort), - Mtu: int64(cfg.MTU), - DisableAutoConnect: cfg.DisableAutoConnect, - ServerSSHAllowed: *cfg.ServerSSHAllowed, - RosenpassEnabled: cfg.RosenpassEnabled, - RosenpassPermissive: cfg.RosenpassPermissive, - LazyConnectionEnabled: cfg.LazyConnectionEnabled, - BlockInbound: cfg.BlockInbound, - DisableNotifications: disableNotifications, - NetworkMonitor: networkMonitor, - DisableDns: disableDNS, - DisableClientRoutes: disableClientRoutes, - DisableServerRoutes: disableServerRoutes, - BlockLanAccess: blockLANAccess, + ManagementUrl: managementURL.String(), + PreSharedKey: preSharedKey, + AdminURL: adminURL.String(), + InterfaceName: cfg.WgIface, + WireguardPort: int64(cfg.WgPort), + Mtu: int64(cfg.MTU), + DisableAutoConnect: cfg.DisableAutoConnect, + ServerSSHAllowed: *cfg.ServerSSHAllowed, + RosenpassEnabled: cfg.RosenpassEnabled, + RosenpassPermissive: cfg.RosenpassPermissive, + LazyConnectionEnabled: cfg.LazyConnectionEnabled, + BlockInbound: cfg.BlockInbound, + DisableNotifications: disableNotifications, + NetworkMonitor: networkMonitor, + DisableDns: disableDNS, + DisableClientRoutes: disableClientRoutes, + DisableServerRoutes: disableServerRoutes, + BlockLanAccess: blockLANAccess, + EnableSSHRoot: enableSSHRoot, + EnableSSHSFTP: enableSSHSFTP, + EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding, + EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding, + DisableSSHAuth: disableSSHAuth, + SshJWTCacheTTL: sshJWTCacheTTL, }, nil } @@ -1385,6 +1661,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { RosenpassEnabled: peerState.RosenpassEnabled, Networks: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), + SshHostKey: peerState.SSHHostKey, } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) } diff --git a/client/server/server_test.go b/client/server/server_test.go index ae5f759ee..96d4c0af0 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -14,6 +14,7 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -316,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 1260bcc78..8e360175d 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -72,6 +72,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { lazyConnectionEnabled := true blockInbound := true mtu := int64(1280) + sshJWTCacheTTL := int32(300) req := &proto.SetConfigRequest{ ProfileName: profName, @@ -102,6 +103,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { CleanDNSLabels: false, DnsRouteInterval: durationpb.New(2 * time.Minute), Mtu: &mtu, + SshJWTCacheTTL: &sshJWTCacheTTL, } _, err = s.SetConfig(ctx, req) @@ -146,6 +148,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) require.Equal(t, uint16(mtu), cfg.MTU) + require.NotNil(t, cfg.SSHJWTCacheTTL) + require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL) verifyAllFieldsCovered(t, req) } @@ -167,30 +171,36 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { } expectedFields := map[string]bool{ - "ManagementUrl": true, - "AdminURL": true, - "RosenpassEnabled": true, - "RosenpassPermissive": true, - "ServerSSHAllowed": true, - "InterfaceName": true, - "WireguardPort": true, - "OptionalPreSharedKey": true, - "DisableAutoConnect": true, - "NetworkMonitor": true, - "DisableClientRoutes": true, - "DisableServerRoutes": true, - "DisableDns": true, - "DisableFirewall": true, - "BlockLanAccess": true, - "DisableNotifications": true, - "LazyConnectionEnabled": true, - "BlockInbound": true, - "NatExternalIPs": true, - "CustomDNSAddress": true, - "ExtraIFaceBlacklist": true, - "DnsLabels": true, - "DnsRouteInterval": true, - "Mtu": true, + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + "EnableSSHRoot": true, + "EnableSSHSFTP": true, + "EnableSSHLocalPortForwarding": true, + "EnableSSHRemotePortForwarding": true, + "DisableSSHAuth": true, + "SshJWTCacheTTL": true, } val := reflect.ValueOf(req).Elem() @@ -221,29 +231,35 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) { // Map of CLI flag names to their corresponding SetConfigRequest field names. // This map must be updated when adding new config-related CLI flags. flagToField := map[string]string{ - "management-url": "ManagementUrl", - "admin-url": "AdminURL", - "enable-rosenpass": "RosenpassEnabled", - "rosenpass-permissive": "RosenpassPermissive", - "allow-server-ssh": "ServerSSHAllowed", - "interface-name": "InterfaceName", - "wireguard-port": "WireguardPort", - "preshared-key": "OptionalPreSharedKey", - "disable-auto-connect": "DisableAutoConnect", - "network-monitor": "NetworkMonitor", - "disable-client-routes": "DisableClientRoutes", - "disable-server-routes": "DisableServerRoutes", - "disable-dns": "DisableDns", - "disable-firewall": "DisableFirewall", - "block-lan-access": "BlockLanAccess", - "block-inbound": "BlockInbound", - "enable-lazy-connection": "LazyConnectionEnabled", - "external-ip-map": "NatExternalIPs", - "dns-resolver-address": "CustomDNSAddress", - "extra-iface-blacklist": "ExtraIFaceBlacklist", - "extra-dns-labels": "DnsLabels", - "dns-router-interval": "DnsRouteInterval", - "mtu": "Mtu", + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + "enable-ssh-root": "EnableSSHRoot", + "enable-ssh-sftp": "EnableSSHSFTP", + "enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding", + "enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding", + "disable-ssh-auth": "DisableSSHAuth", + "ssh-jwt-cache-ttl": "SshJWTCacheTTL", } // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). diff --git a/client/server/state_generic.go b/client/server/state_generic.go index e6c7bdd44..980ba0cda 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -6,9 +6,11 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/ssh/config" ) func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) + mgr.RegisterState(&config.ShutdownState{}) } diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 087628907..019477d8e 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/ssh/config" ) func registerStates(mgr *statemanager.Manager) { @@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&nftables.ShutdownState{}) mgr.RegisterState(&iptables.ShutdownState{}) + mgr.RegisterState(&config.ShutdownState{}) } diff --git a/client/ssh/client.go b/client/ssh/client.go deleted file mode 100644 index afba347f8..000000000 --- a/client/ssh/client.go +++ /dev/null @@ -1,118 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "net" - "os" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/term" -) - -// Client wraps crypto/ssh Client to simplify usage -type Client struct { - client *ssh.Client -} - -// Close closes the wrapped SSH Client -func (c *Client) Close() error { - return c.client.Close() -} - -// OpenTerminal starts an interactive terminal session with the remote SSH server -func (c *Client) OpenTerminal() error { - session, err := c.client.NewSession() - if err != nil { - return fmt.Errorf("failed to open new session: %v", err) - } - defer func() { - err := session.Close() - if err != nil { - return - } - }() - - fd := int(os.Stdout.Fd()) - state, err := term.MakeRaw(fd) - if err != nil { - return fmt.Errorf("failed to run raw terminal: %s", err) - } - defer func() { - err := term.Restore(fd, state) - if err != nil { - return - } - }() - - w, h, err := term.GetSize(fd) - if err != nil { - return fmt.Errorf("terminal get size: %s", err) - } - - modes := ssh.TerminalModes{ - ssh.ECHO: 1, - ssh.TTY_OP_ISPEED: 14400, - ssh.TTY_OP_OSPEED: 14400, - } - - terminal := os.Getenv("TERM") - if terminal == "" { - terminal = "xterm-256color" - } - if err := session.RequestPty(terminal, h, w, modes); err != nil { - return fmt.Errorf("failed requesting pty session with xterm: %s", err) - } - - session.Stdout = os.Stdout - session.Stderr = os.Stderr - session.Stdin = os.Stdin - - if err := session.Shell(); err != nil { - return fmt.Errorf("failed to start login shell on the remote host: %s", err) - } - - if err := session.Wait(); err != nil { - if e, ok := err.(*ssh.ExitError); ok { - if e.ExitStatus() == 130 { - return nil - } - } - return fmt.Errorf("failed running SSH session: %s", err) - } - - return nil -} - -// DialWithKey connects to the remote SSH server with a provided private key file (PEM). -func DialWithKey(addr, user string, privateKey []byte) (*Client, error) { - - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - return nil, err - } - - config := &ssh.ClientConfig{ - User: user, - Timeout: 5 * time.Second, - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, - HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }), - } - - return Dial("tcp", addr, config) -} - -// Dial connects to the remote SSH server. -func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) { - client, err := ssh.Dial(network, addr, config) - if err != nil { - return nil, err - } - return &Client{ - client: client, - }, nil -} diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go new file mode 100644 index 000000000..882056374 --- /dev/null +++ b/client/ssh/client/client.go @@ -0,0 +1,699 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" + "golang.org/x/term" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/detection" +) + +const ( + // DefaultDaemonAddr is the default address for the NetBird daemon + DefaultDaemonAddr = "unix:///var/run/netbird.sock" + // DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows + DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731" +) + +// Client wraps crypto/ssh Client for simplified SSH operations +type Client struct { + client *ssh.Client + terminalState *term.State + terminalFd int + + windowsStdoutMode uint32 // nolint:unused + windowsStdinMode uint32 // nolint:unused +} + +func (c *Client) Close() error { + return c.client.Close() +} + +func (c *Client) OpenTerminal(ctx context.Context) error { + session, err := c.client.NewSession() + if err != nil { + return fmt.Errorf("new session: %w", err) + } + defer func() { + if err := session.Close(); err != nil { + log.Debugf("session close error: %v", err) + } + }() + + if err := c.setupTerminalMode(ctx, session); err != nil { + return err + } + + c.setupSessionIO(session) + + if err := session.Shell(); err != nil { + return fmt.Errorf("start shell: %w", err) + } + + return c.waitForSession(ctx, session) +} + +// setupSessionIO connects session streams to local terminal +func (c *Client) setupSessionIO(session *ssh.Session) { + session.Stdout = os.Stdout + session.Stderr = os.Stderr + session.Stdin = os.Stdin +} + +// waitForSession waits for the session to complete with context cancellation +func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error { + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + defer c.restoreTerminal() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return c.handleSessionError(err) + } +} + +// handleSessionError processes session termination errors +func (c *Client) handleSessionError(err error) error { + if err == nil { + return nil + } + + var e *ssh.ExitError + var em *ssh.ExitMissingError + if !errors.As(err, &e) && !errors.As(err, &em) { + return fmt.Errorf("session wait: %w", err) + } + + return nil +} + +// restoreTerminal restores the terminal to its original state +func (c *Client) restoreTerminal() { + if c.terminalState != nil { + _ = term.Restore(c.terminalFd, c.terminalState) + c.terminalState = nil + c.terminalFd = 0 + } + + if err := c.restoreWindowsConsoleState(); err != nil { + log.Debugf("restore Windows console state: %v", err) + } +} + +// ExecuteCommand executes a command on the remote host and returns the output +func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return nil, err + } + defer cleanup() + + output, err := session.CombinedOutput(command) + if err != nil { + var e *ssh.ExitError + var em *ssh.ExitMissingError + if !errors.As(err, &e) && !errors.As(err, &em) { + return output, fmt.Errorf("execute command: %w", err) + } + } + + return output, nil +} + +// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal +func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return fmt.Errorf("create session: %w", err) + } + defer cleanup() + + c.setupSessionIO(session) + + if err := session.Start(command); err != nil { + return fmt.Errorf("start command: %w", err) + } + + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + select { + case <-done: + return ctx.Err() + case <-time.After(100 * time.Millisecond): + return ctx.Err() + } + case err := <-done: + return c.handleCommandError(err) + } +} + +// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions +func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return fmt.Errorf("create session: %w", err) + } + defer cleanup() + + if err := c.setupTerminalMode(ctx, session); err != nil { + return fmt.Errorf("setup terminal mode: %w", err) + } + + c.setupSessionIO(session) + + if err := session.Start(command); err != nil { + return fmt.Errorf("start command: %w", err) + } + + defer c.restoreTerminal() + + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + select { + case <-done: + return ctx.Err() + case <-time.After(100 * time.Millisecond): + return ctx.Err() + } + case err := <-done: + return c.handleCommandError(err) + } +} + +// handleCommandError processes command execution errors +func (c *Client) handleCommandError(err error) error { + if err == nil { + return nil + } + + var e *ssh.ExitError + var em *ssh.ExitMissingError + if errors.As(err, &e) || errors.As(err, &em) { + return err + } + + return fmt.Errorf("execute command: %w", err) +} + +// setupContextCancellation sets up context cancellation for a session +func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() { + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + _ = session.Close() + case <-done: + } + }() + return func() { close(done) } +} + +// createSession creates a new SSH session with context cancellation setup +func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) { + session, err := c.client.NewSession() + if err != nil { + return nil, nil, fmt.Errorf("new session: %w", err) + } + + cancel := c.setupContextCancellation(ctx, session) + cleanup := func() { + cancel() + _ = session.Close() + } + + return session, cleanup, nil +} + +// getDefaultDaemonAddr returns the daemon address from environment or default for the OS +func getDefaultDaemonAddr() string { + if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" { + return addr + } + if runtime.GOOS == "windows" { + return DefaultDaemonAddrWindows + } + return DefaultDaemonAddr +} + +// DialOptions contains options for SSH connections +type DialOptions struct { + KnownHostsFile string + IdentityFile string + DaemonAddr string + SkipCachedToken bool + InsecureSkipVerify bool +} + +// Dial connects to the given ssh server with specified options +func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) { + daemonAddr := opts.DaemonAddr + if daemonAddr == "" { + daemonAddr = getDefaultDaemonAddr() + } + opts.DaemonAddr = daemonAddr + + hostKeyCallback, err := createHostKeyCallback(opts) + if err != nil { + return nil, fmt.Errorf("create host key callback: %w", err) + } + + config := &ssh.ClientConfig{ + User: user, + Timeout: 30 * time.Second, + HostKeyCallback: hostKeyCallback, + } + + if opts.IdentityFile != "" { + authMethod, err := createSSHKeyAuth(opts.IdentityFile) + if err != nil { + return nil, fmt.Errorf("create SSH key auth: %w", err) + } + config.Auth = append(config.Auth, authMethod) + } + + return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken) +} + +// dialSSH establishes an SSH connection without JWT authentication +func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) { + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + + clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("connection close after handshake failure: %v", closeErr) + } + return nil, fmt.Errorf("ssh handshake: %w", err) + } + + client := ssh.NewClient(clientConn, chans, reqs) + return &Client{ + client: client, + }, nil +} + +// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection +func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("parse address %s: %w", addr, err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("parse port %s: %w", portStr, err) + } + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) + if err != nil { + return nil, fmt.Errorf("SSH server detection failed: %w", err) + } + + if !serverType.RequiresJWT() { + return dialSSH(ctx, network, addr, config) + } + + jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout) + defer cancel() + + jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache) + if err != nil { + return nil, fmt.Errorf("request JWT token: %w", err) + } + + configWithJWT := nbssh.AddJWTAuth(config, jwtToken) + return dialSSH(ctx, network, addr, configWithJWT) +} + +// requestJWTToken requests a JWT token from the NetBird daemon +func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) { + hint := profilemanager.GetLoginHint() + + conn, err := connectToDaemon(daemonAddr) + if err != nil { + return "", fmt.Errorf("connect to daemon: %w", err) + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint) +} + +// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon +func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { + conn, err := connectToDaemon(daemonAddr) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("daemon connection close error: %v", err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + verifier := nbssh.NewDaemonHostKeyVerifier(client) + callback := nbssh.CreateHostKeyCallback(verifier) + return callback(hostname, remote, key) +} + +func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) { + addr := strings.TrimPrefix(daemonAddr, "tcp://") + + conn, err := grpc.NewClient( + addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err) + return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err) + } + + return conn, nil +} + +// getKnownHostsFiles returns paths to known_hosts files in order of preference +func getKnownHostsFiles() []string { + var files []string + + // User's known_hosts file (highest priority) + if homeDir, err := os.UserHomeDir(); err == nil { + userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts") + files = append(files, userKnownHosts) + } + + // NetBird managed known_hosts files + if runtime.GOOS == "windows" { + programData := os.Getenv("PROGRAMDATA") + if programData == "" { + programData = `C:\ProgramData` + } + netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird") + files = append(files, netbirdKnownHosts) + } else { + files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird") + files = append(files, "/etc/ssh/ssh_known_hosts") + } + + return files +} + +// createHostKeyCallback creates a host key verification callback +func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) { + if opts.InsecureSkipVerify { + return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil { + return nil + } + return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile) + }, nil +} + +func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { + if daemonAddr == "" { + return fmt.Errorf("no daemon address") + } + return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr) +} + +func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error { + knownHostsFiles := getKnownHostsFilesList(knownHostsFile) + hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles) + + for _, callback := range hostKeyCallbacks { + if err := callback(hostname, remote, key); err == nil { + return nil + } + } + return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname) +} + +func getKnownHostsFilesList(knownHostsFile string) []string { + if knownHostsFile != "" { + return []string{knownHostsFile} + } + return getKnownHostsFiles() +} + +func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback { + var hostKeyCallbacks []ssh.HostKeyCallback + for _, file := range knownHostsFiles { + if callback, err := knownhosts.New(file); err == nil { + hostKeyCallbacks = append(hostKeyCallbacks, callback) + } + } + return hostKeyCallbacks +} + +// createSSHKeyAuth creates SSH key authentication from a private key file +func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) { + keyData, err := os.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err) + } + + signer, err := ssh.ParsePrivateKey(keyData) + if err != nil { + return nil, fmt.Errorf("parse SSH private key: %w", err) + } + + return ssh.PublicKeys(signer), nil +} + +// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr +func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error { + localListener, err := net.Listen("tcp", localAddr) + if err != nil { + return fmt.Errorf("listen on %s: %w", localAddr, err) + } + + go func() { + defer func() { + if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Debugf("local listener close error: %v", err) + } + }() + for { + localConn, err := localListener.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + continue + } + + go c.handleLocalForward(localConn, remoteAddr) + } + }() + + <-ctx.Done() + if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Debugf("local listener close error: %v", err) + } + return ctx.Err() +} + +// handleLocalForward handles a single local port forwarding connection +func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { + defer func() { + if err := localConn.Close(); err != nil { + log.Debugf("local connection close error: %v", err) + } + }() + + channel, err := c.client.Dial("tcp", remoteAddr) + if err != nil { + if strings.Contains(err.Error(), "administratively prohibited") { + _, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n") + } else { + log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err) + } + return + } + defer func() { + if err := channel.Close(); err != nil { + log.Debugf("remote channel close error: %v", err) + } + }() + + go func() { + if _, err := io.Copy(channel, localConn); err != nil { + log.Debugf("local forward copy error (local->remote): %v", err) + } + }() + + if _, err := io.Copy(localConn, channel); err != nil { + log.Debugf("local forward copy error (remote->local): %v", err) + } +} + +// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr +func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error { + host, port, err := c.parseRemoteAddress(remoteAddr) + if err != nil { + return fmt.Errorf("parse remote address: %w", err) + } + + req := c.buildTCPIPForwardRequest(host, port) + if err := c.sendTCPIPForwardRequest(req); err != nil { + return fmt.Errorf("setup remote forward: %w", err) + } + + go c.handleRemoteForwardChannels(ctx, localAddr) + + <-ctx.Done() + + if err := c.cancelTCPIPForwardRequest(req); err != nil { + return fmt.Errorf("cancel tcpip-forward: %w", err) + } + return ctx.Err() +} + +// parseRemoteAddress parses host and port from remote address string +func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) { + host, portStr, err := net.SplitHostPort(remoteAddr) + if err != nil { + return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err) + } + + return host, uint32(port), nil +} + +// buildTCPIPForwardRequest creates a tcpip-forward request message +func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg { + return tcpipForwardMsg{ + Host: host, + Port: port, + } +} + +// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding +func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error { + ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req)) + if err != nil { + return fmt.Errorf("send tcpip-forward request: %w", err) + } + if !ok { + return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)") + } + return nil +} + +// cancelTCPIPForwardRequest cancels the tcpip-forward request +func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error { + _, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req)) + if err != nil { + return fmt.Errorf("send cancel-tcpip-forward request: %w", err) + } + return nil +} + +// handleRemoteForwardChannels handles incoming forwarded-tcpip channels +func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) { + // Get the channel once - subsequent calls return nil! + channelRequests := c.client.HandleChannelOpen("forwarded-tcpip") + if channelRequests == nil { + log.Debugf("forwarded-tcpip channel type already being handled") + return + } + + for { + select { + case <-ctx.Done(): + return + case newChan := <-channelRequests: + if newChan != nil { + go c.handleRemoteForwardChannel(newChan, localAddr) + } + } + } +} + +// handleRemoteForwardChannel handles a single forwarded-tcpip channel +func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) { + channel, reqs, err := newChan.Accept() + if err != nil { + return + } + defer func() { + if err := channel.Close(); err != nil { + log.Debugf("remote channel close error: %v", err) + } + }() + + go ssh.DiscardRequests(reqs) + + localConn, err := net.Dial("tcp", localAddr) + if err != nil { + return + } + defer func() { + if err := localConn.Close(); err != nil { + log.Debugf("local connection close error: %v", err) + } + }() + + go func() { + if _, err := io.Copy(localConn, channel); err != nil { + log.Debugf("remote forward copy error (remote->local): %v", err) + } + }() + + if _, err := io.Copy(channel, localConn); err != nil { + log.Debugf("remote forward copy error (local->remote): %v", err) + } +} + +// tcpipForwardMsg represents the structure for tcpip-forward requests +type tcpipForwardMsg struct { + Host string + Port uint32 +} diff --git a/client/ssh/client/client_test.go b/client/ssh/client/client_test.go new file mode 100644 index 000000000..e38e02a86 --- /dev/null +++ b/client/ssh/client/client_test.go @@ -0,0 +1,512 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/ssh" + sshserver "github.com/netbirdio/netbird/client/ssh/server" + "github.com/netbirdio/netbird/client/ssh/testutil" +) + +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Guard against infinite recursion when test binary is called as "netbird ssh exec" + // This happens when running tests as non-privileged user with fallback + if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { + // Just exit with error to break the recursion + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") + os.Exit(1) + } + + // Run tests + code := m.Run() + + // Cleanup any created test users + testutil.CleanupTestUsers() + + os.Exit(code) +} + +func TestSSHClient_DialWithKey(t *testing.T) { + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create and start server + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Test Dial + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + // Verify client is connected + assert.NotNil(t, client.client) +} + +func TestSSHClient_CommandExecution(t *testing.T) { + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues") + } + + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + t.Run("ExecuteCommand captures output", func(t *testing.T) { + output, err := client.ExecuteCommand(ctx, "echo hello") + assert.NoError(t, err) + assert.Contains(t, string(output), "hello") + }) + + t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) { + err := client.ExecuteCommandWithIO(ctx, "echo world") + assert.NoError(t, err) + }) + + t.Run("commands with flags work", func(t *testing.T) { + output, err := client.ExecuteCommand(ctx, "echo -n test_flag") + assert.NoError(t, err) + assert.Equal(t, "test_flag", strings.TrimSpace(string(output))) + }) + + t.Run("non-zero exit codes don't return errors", func(t *testing.T) { + var testCmd string + if runtime.GOOS == "windows" { + testCmd = "echo hello | Select-String notfound" + } else { + testCmd = "echo 'hello' | grep 'notfound'" + } + _, err := client.ExecuteCommand(ctx, testCmd) + assert.NoError(t, err) + }) +} + +func TestSSHClient_ConnectionHandling(t *testing.T) { + server, serverAddr, _ := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Generate client key for multiple connections + + const numClients = 3 + clients := make([]*Client, numClients) + + currentUser := testutil.GetTestUsername(t) + for i := 0; i < numClients; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + cancel() + require.NoError(t, err, "Client %d should connect successfully", i) + clients[i] = client + } + + for i, client := range clients { + err := client.Close() + assert.NoError(t, err, "Client %d should close without error", i) + } +} + +func TestSSHClient_ContextCancellation(t *testing.T) { + server, serverAddr, _ := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + t.Run("connection with short timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + _, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + // Check for actual timeout-related errors rather than string matching + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + strings.Contains(err.Error(), "timeout"), + "Expected timeout-related error, got: %v", err) + } + }) + + t.Run("command execution cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("client close error: %v", err) + } + }() + + cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cmdCancel() + + err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10") + if err != nil { + var exitMissingErr *cryptossh.ExitMissingError + isValidCancellation := errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + errors.As(err, &exitMissingErr) + assert.True(t, isValidCancellation, "Should handle command cancellation properly") + } + }) +} + +func TestSSHClient_NoAuthMode(t *testing.T) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + + t.Run("any key succeeds in no-auth mode", func(t *testing.T) { + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + assert.NoError(t, err) + if client != nil { + require.NoError(t, client.Close(), "Client should close without error") + } + }) +} + +func TestSSHClient_TerminalState(t *testing.T) { + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + assert.Nil(t, client.terminalState) + assert.Equal(t, 0, client.terminalFd) + + client.restoreTerminal() + assert.Nil(t, client.terminalState) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := client.OpenTerminal(ctx) + // In test environment without a real terminal, this may complete quickly or timeout + // Both behaviors are acceptable for testing terminal state management + if err != nil { + if runtime.GOOS == "windows" { + assert.True(t, + strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "console"), + "Should timeout or have console error on Windows") + } else { + // On Unix systems in test environment, we may get various errors + // including timeouts or terminal-related errors + assert.True(t, + strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "terminal") || + strings.Contains(err.Error(), "pty"), + "Expected timeout or terminal-related error, got: %v", err) + } + } +} + +func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + + return server, serverAddr, client +} + +func TestSSHClient_PortForwarding(t *testing.T) { + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + t.Run("local forwarding times out gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080") + assert.Error(t, err) + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + strings.Contains(err.Error(), "connection"), + "Expected context or connection error") + }) + + t.Run("remote forwarding denied", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080") + assert.Error(t, err) + assert.True(t, + strings.Contains(err.Error(), "denied") || + strings.Contains(err.Error(), "disabled"), + "Should be denied by default") + }) + + t.Run("invalid addresses fail", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080") + assert.Error(t, err) + + err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address") + assert.Error(t, err) + }) +} + +func TestSSHClient_PortForwardingDataTransfer(t *testing.T) { + if testing.Short() { + t.Skip("Skipping data transfer test in short mode") + } + + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowLocalPortForwarding(true) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Port forwarding requires the actual current user, not test user + realUser, err := getRealCurrentUser() + require.NoError(t, err) + + // Skip if running as system account that can't do port forwarding + if testutil.IsSystemAccount(realUser) { + t.Skipf("Skipping port forwarding test - running as system account: %s", realUser) + } + + client, err := Dial(ctx, serverAddr, realUser, DialOptions{ + InsecureSkipVerify: true, // Skip host key verification for test + }) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("client close error: %v", err) + } + }() + + testServer, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + if err := testServer.Close(); err != nil { + t.Logf("test server close error: %v", err) + } + }() + + testServerAddr := testServer.Addr().String() + expectedResponse := "Hello, World!" + + go func() { + for { + conn, err := testServer.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { + if err := c.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + buf := make([]byte, 1024) + if _, err := c.Read(buf); err != nil { + t.Logf("connection read error: %v", err) + return + } + if _, err := c.Write([]byte(expectedResponse)); err != nil { + t.Logf("connection write error: %v", err) + } + }(conn) + } + }() + + localListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + localAddr := localListener.Addr().String() + if err := localListener.Close(); err != nil { + t.Logf("local listener close error: %v", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go func() { + err := client.LocalPortForward(ctx, localAddr, testServerAddr) + if err != nil && !errors.Is(err, context.Canceled) { + if isWindowsPrivilegeError(err) { + t.Logf("Port forward failed due to Windows privilege restrictions: %v", err) + } else { + t.Logf("Port forward error: %v", err) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second) + require.NoError(t, err) + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + _, err = conn.Write([]byte("test")) + require.NoError(t, err) + + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Logf("set read deadline error: %v", err) + } + response := make([]byte, len(expectedResponse)) + n, err := io.ReadFull(conn, response) + require.NoError(t, err) + assert.Equal(t, len(expectedResponse), n) + assert.Equal(t, expectedResponse, string(response)) +} + +// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding +func getRealCurrentUser() (string, error) { + if runtime.GOOS == "windows" { + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + } + + if username := os.Getenv("USER"); username != "" { + return username, nil + } + + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + + return "", fmt.Errorf("unable to determine current user") +} + +// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions +func isWindowsPrivilegeError(err error) bool { + if err == nil { + return false + } + + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD + strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess) + strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser) + strings.Contains(errStr, "privilege") || + strings.Contains(errStr, "access denied") || + strings.Contains(errStr, "user authentication failed") +} diff --git a/client/ssh/client/terminal_unix.go b/client/ssh/client/terminal_unix.go new file mode 100644 index 000000000..aaa3418f9 --- /dev/null +++ b/client/ssh/client/terminal_unix.go @@ -0,0 +1,127 @@ +//go:build !windows + +package client + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/term" +) + +func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error { + stdinFd := int(os.Stdin.Fd()) + + if !term.IsTerminal(stdinFd) { + return c.setupNonTerminalMode(ctx, session) + } + + fd := int(os.Stdin.Fd()) + + state, err := term.MakeRaw(fd) + if err != nil { + return c.setupNonTerminalMode(ctx, session) + } + + if err := c.setupTerminal(session, fd); err != nil { + if restoreErr := term.Restore(fd, state); restoreErr != nil { + log.Debugf("restore terminal state: %v", restoreErr) + } + return err + } + + c.terminalState = state + c.terminalFd = fd + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + go func() { + defer signal.Stop(sigChan) + select { + case <-ctx.Done(): + if err := term.Restore(fd, state); err != nil { + log.Debugf("restore terminal state: %v", err) + } + case sig := <-sigChan: + if err := term.Restore(fd, state); err != nil { + log.Debugf("restore terminal state: %v", err) + } + signal.Reset(sig) + s, ok := sig.(syscall.Signal) + if !ok { + log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig) + return + } + if err := syscall.Kill(syscall.Getpid(), s); err != nil { + log.Debugf("kill process with signal %v: %v", s, err) + } + } + }() + + return nil +} + +func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error { + return nil +} + +// restoreWindowsConsoleState is a no-op on Unix systems +func (c *Client) restoreWindowsConsoleState() error { + return nil +} + +func (c *Client) setupTerminal(session *ssh.Session, fd int) error { + w, h, err := term.GetSize(fd) + if err != nil { + return fmt.Errorf("get terminal size: %w", err) + } + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + // Ctrl+C + ssh.VINTR: 3, + // Ctrl+\ + ssh.VQUIT: 28, + // Backspace + ssh.VERASE: 127, + // Ctrl+U + ssh.VKILL: 21, + // Ctrl+D + ssh.VEOF: 4, + ssh.VEOL: 0, + ssh.VEOL2: 0, + // Ctrl+Q + ssh.VSTART: 17, + // Ctrl+S + ssh.VSTOP: 19, + // Ctrl+Z + ssh.VSUSP: 26, + // Ctrl+O + ssh.VDISCARD: 15, + // Ctrl+R + ssh.VREPRINT: 18, + // Ctrl+W + ssh.VWERASE: 23, + // Ctrl+V + ssh.VLNEXT: 22, + } + + terminal := os.Getenv("TERM") + if terminal == "" { + terminal = "xterm-256color" + } + + if err := session.RequestPty(terminal, h, w, modes); err != nil { + return fmt.Errorf("request pty: %w", err) + } + + return nil +} diff --git a/client/ssh/client/terminal_windows.go b/client/ssh/client/terminal_windows.go new file mode 100644 index 000000000..462438317 --- /dev/null +++ b/client/ssh/client/terminal_windows.go @@ -0,0 +1,265 @@ +package client + +import ( + "context" + "errors" + "fmt" + "os" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +const ( + enableProcessedInput = 0x0001 + enableLineInput = 0x0002 + enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT + enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode) + enableVirtualTerminalInput = 0x0200 +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + +// ConsoleUnavailableError indicates that Windows console handles are not available +// (e.g., in CI environments where stdout/stdin are redirected) +type ConsoleUnavailableError struct { + Operation string + Err error +} + +func (e *ConsoleUnavailableError) Error() string { + return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err) +} + +func (e *ConsoleUnavailableError) Unwrap() error { + return e.Err +} + +type coord struct { + x, y int16 +} + +type smallRect struct { + left, top, right, bottom int16 +} + +type consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes uint16 + window smallRect + maximumWindowSize coord +} + +func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error { + if err := c.saveWindowsConsoleState(); err != nil { + var consoleErr *ConsoleUnavailableError + if errors.As(err, &consoleErr) { + log.Debugf("console unavailable, not requesting PTY: %v", err) + return nil + } + return fmt.Errorf("save console state: %w", err) + } + + if err := c.enableWindowsVirtualTerminal(); err != nil { + var consoleErr *ConsoleUnavailableError + if errors.As(err, &consoleErr) { + log.Debugf("virtual terminal unavailable: %v", err) + } else { + return fmt.Errorf("failed to enable virtual terminal: %w", err) + } + } + + w, h := c.getWindowsConsoleSize() + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + ssh.ICRNL: 1, + ssh.OPOST: 1, + ssh.ONLCR: 1, + ssh.ISIG: 1, + ssh.ICANON: 1, + ssh.VINTR: 3, // Ctrl+C + ssh.VQUIT: 28, // Ctrl+\ + ssh.VERASE: 127, // Backspace + ssh.VKILL: 21, // Ctrl+U + ssh.VEOF: 4, // Ctrl+D + ssh.VEOL: 0, + ssh.VEOL2: 0, + ssh.VSTART: 17, // Ctrl+Q + ssh.VSTOP: 19, // Ctrl+S + ssh.VSUSP: 26, // Ctrl+Z + ssh.VDISCARD: 15, // Ctrl+O + ssh.VWERASE: 23, // Ctrl+W + ssh.VLNEXT: 22, // Ctrl+V + ssh.VREPRINT: 18, // Ctrl+R + } + + if err := session.RequestPty("xterm-256color", h, w, modes); err != nil { + if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil { + log.Debugf("restore Windows console state: %v", restoreErr) + } + return fmt.Errorf("request pty: %w", err) + } + + return nil +} + +func (c *Client) saveWindowsConsoleState() error { + defer func() { + if r := recover(); r != nil { + log.Debugf("panic in saveWindowsConsoleState: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + + var stdoutMode, stdinMode uint32 + + ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode))) + if ret == 0 { + log.Debugf("failed to get stdout console mode: %v", err) + return &ConsoleUnavailableError{ + Operation: "get stdout console mode", + Err: err, + } + } + + ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode))) + if ret == 0 { + log.Debugf("failed to get stdin console mode: %v", err) + return &ConsoleUnavailableError{ + Operation: "get stdin console mode", + Err: err, + } + } + + c.terminalFd = 1 + c.windowsStdoutMode = stdoutMode + c.windowsStdinMode = stdinMode + + log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode) + return nil +} + +func (c *Client) enableWindowsVirtualTerminal() (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + var mode uint32 + + ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode))) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "get stdout console mode for VT", + Err: winErr, + } + } + + mode |= enableVirtualTerminalProcessing + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode)) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "enable virtual terminal processing", + Err: winErr, + } + } + + ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode))) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "get stdin console mode for VT", + Err: winErr, + } + } + + mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput) + mode |= enableVirtualTerminalInput + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode)) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "set stdin raw mode", + Err: winErr, + } + } + + log.Debugf("enabled Windows virtual terminal processing") + return nil +} + +func (c *Client) getWindowsConsoleSize() (int, int) { + defer func() { + if r := recover(); r != nil { + log.Debugf("panic in getWindowsConsoleSize: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + var csbi consoleScreenBufferInfo + + ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi))) + if ret == 0 { + log.Debugf("failed to get console buffer info, using defaults: %v", err) + return 80, 24 + } + + width := int(csbi.window.right - csbi.window.left + 1) + height := int(csbi.window.bottom - csbi.window.top + 1) + + log.Debugf("Windows console size: %dx%d", width, height) + return width, height +} + +func (c *Client) restoreWindowsConsoleState() error { + var err error + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r) + } + }() + + if c.terminalFd != 1 { + return nil + } + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + + ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode)) + if ret == 0 { + log.Debugf("failed to restore stdout console mode: %v", winErr) + if err == nil { + err = fmt.Errorf("restore stdout console mode: %w", winErr) + } + } + + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode)) + if ret == 0 { + log.Debugf("failed to restore stdin console mode: %v", winErr) + if err == nil { + err = fmt.Errorf("restore stdin console mode: %w", winErr) + } + } + + c.terminalFd = 0 + c.windowsStdoutMode = 0 + c.windowsStdinMode = 0 + + log.Debugf("restored Windows console state") + return err +} diff --git a/client/ssh/common.go b/client/ssh/common.go new file mode 100644 index 000000000..3beb12806 --- /dev/null +++ b/client/ssh/common.go @@ -0,0 +1,171 @@ +package ssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/proto" +) + +const ( + NetBirdSSHConfigFile = "99-netbird.conf" + + UnixSSHConfigDir = "/etc/ssh/ssh_config.d" + WindowsSSHConfigDir = "ssh/ssh_config.d" +) + +var ( + // ErrPeerNotFound indicates the peer was not found in the network + ErrPeerNotFound = errors.New("peer not found in network") + // ErrNoStoredKey indicates the peer has no stored SSH host key + ErrNoStoredKey = errors.New("peer has no stored SSH host key") +) + +// HostKeyVerifier provides SSH host key verification +type HostKeyVerifier interface { + VerifySSHHostKey(peerAddress string, key []byte) error +} + +// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon +type DaemonHostKeyVerifier struct { + client proto.DaemonServiceClient +} + +// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier +func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier { + return &DaemonHostKeyVerifier{ + client: client, + } +} + +// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon +func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{ + PeerAddress: peerAddress, + }) + if err != nil { + return err + } + + if !response.GetFound() { + return ErrPeerNotFound + } + + storedKeyData := response.GetSshHostKey() + + return VerifyHostKey(storedKeyData, presentedKey, peerAddress) +} + +// RequestJWTToken requests or retrieves a JWT token for SSH authentication +func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) { + req := &proto.RequestJWTAuthRequest{} + if hint != "" { + req.Hint = &hint + } + authResponse, err := client.RequestJWTAuth(ctx, req) + if err != nil { + return "", fmt.Errorf("request JWT auth: %w", err) + } + + if useCache && authResponse.CachedToken != "" { + log.Debug("Using cached authentication token") + return authResponse.CachedToken, nil + } + + if stderr != nil { + _, _ = fmt.Fprintln(stderr, "SSH authentication required.") + _, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete) + if authResponse.UserCode != "" { + _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + } + _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") + } + + tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{ + DeviceCode: authResponse.DeviceCode, + UserCode: authResponse.UserCode, + }) + if err != nil { + return "", fmt.Errorf("wait for JWT token: %w", err) + } + + if stdout != nil { + _, _ = fmt.Fprintln(stdout, "Authentication successful!") + } + return tokenResponse.Token, nil +} + +// VerifyHostKey verifies an SSH host key against stored peer key data. +// Returns nil only if the presented key matches the stored key. +// Returns ErrNoStoredKey if storedKeyData is empty. +// Returns an error if the keys don't match or if parsing fails. +func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error { + if len(storedKeyData) == 0 { + return ErrNoStoredKey + } + + storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData) + if err != nil { + return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err) + } + + if !bytes.Equal(presentedKey, storedPubKey.Marshal()) { + return fmt.Errorf("SSH host key mismatch for %s", peerAddress) + } + + return nil +} + +// AddJWTAuth prepends JWT password authentication to existing auth methods. +// This ensures JWT auth is tried first while preserving any existing auth methods. +func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig { + configWithJWT := *config + configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...) + return &configWithJWT +} + +// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier. +// It tries multiple addresses (hostname, IP) for the peer before failing. +func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + addresses := buildAddressList(hostname, remote) + presentedKey := key.Marshal() + + for _, addr := range addresses { + if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil { + if errors.Is(err, ErrPeerNotFound) { + // Try other addresses for this peer + continue + } + return err + } + // Verified + return nil + } + + return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname) + } +} + +// buildAddressList creates a list of addresses to check for host key verification. +// It includes the original hostname and extracts the host part from the remote address if different. +func buildAddressList(hostname string, remote net.Addr) []string { + addresses := []string{hostname} + if host, _, err := net.SplitHostPort(remote.String()); err == nil { + if host != hostname { + addresses = append(addresses, host) + } + } + return addresses +} diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go new file mode 100644 index 000000000..03a136de3 --- /dev/null +++ b/client/ssh/config/manager.go @@ -0,0 +1,282 @@ +package config + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" + + nbssh "github.com/netbirdio/netbird/client/ssh" +) + +const ( + EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG" + + EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG" + + MaxPeersForSSHConfig = 200 + + fileWriteTimeout = 2 * time.Second +) + +func isSSHConfigDisabled() bool { + value := os.Getenv(EnvDisableSSHConfig) + if value == "" { + return false + } + + disabled, err := strconv.ParseBool(value) + if err != nil { + return true + } + return disabled +} + +func isSSHConfigForced() bool { + value := os.Getenv(EnvForceSSHConfig) + if value == "" { + return false + } + + forced, err := strconv.ParseBool(value) + if err != nil { + return true + } + return forced +} + +// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count +func shouldGenerateSSHConfig(peerCount int) bool { + if isSSHConfigDisabled() { + return false + } + + if isSSHConfigForced() { + return true + } + + return peerCount <= MaxPeersForSSHConfig +} + +// writeFileWithTimeout writes data to a file with a timeout +func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error { + ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- os.WriteFile(filename, data, perm) + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename) + } +} + +// Manager handles SSH client configuration for NetBird peers +type Manager struct { + sshConfigDir string + sshConfigFile string +} + +// PeerSSHInfo represents a peer's SSH configuration information +type PeerSSHInfo struct { + Hostname string + IP string + FQDN string +} + +// New creates a new SSH config manager +func New() *Manager { + sshConfigDir := getSystemSSHConfigDir() + return &Manager{ + sshConfigDir: sshConfigDir, + sshConfigFile: nbssh.NetBirdSSHConfigFile, + } +} + +// getSystemSSHConfigDir returns platform-specific SSH configuration directory +func getSystemSSHConfigDir() string { + if runtime.GOOS == "windows" { + return getWindowsSSHConfigDir() + } + return nbssh.UnixSSHConfigDir +} + +func getWindowsSSHConfigDir() string { + programData := os.Getenv("PROGRAMDATA") + if programData == "" { + programData = `C:\ProgramData` + } + return filepath.Join(programData, nbssh.WindowsSSHConfigDir) +} + +// SetupSSHClientConfig creates SSH client configuration for NetBird peers +func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error { + if !shouldGenerateSSHConfig(len(peers)) { + m.logSkipReason(len(peers)) + return nil + } + + sshConfig, err := m.buildSSHConfig(peers) + if err != nil { + return fmt.Errorf("build SSH config: %w", err) + } + return m.writeSSHConfig(sshConfig) +} + +func (m *Manager) logSkipReason(peerCount int) { + if isSSHConfigDisabled() { + log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig) + } else { + log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.", + peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig) + } +} + +func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) { + sshConfig := m.buildConfigHeader() + + var allHostPatterns []string + for _, peer := range peers { + hostPatterns := m.buildHostPatterns(peer) + allHostPatterns = append(allHostPatterns, hostPatterns...) + } + + if len(allHostPatterns) > 0 { + peerConfig, err := m.buildPeerConfig(allHostPatterns) + if err != nil { + return "", err + } + sshConfig += peerConfig + } + + return sshConfig, nil +} + +func (m *Manager) buildConfigHeader() string { + return "# NetBird SSH client configuration\n" + + "# Generated automatically - do not edit manually\n" + + "#\n" + + "# To disable SSH config management, use:\n" + + "# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" + + "#\n\n" +} + +func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { + uniquePatterns := make(map[string]bool) + var deduplicatedPatterns []string + for _, pattern := range allHostPatterns { + if !uniquePatterns[pattern] { + uniquePatterns[pattern] = true + deduplicatedPatterns = append(deduplicatedPatterns, pattern) + } + } + + execPath, err := m.getNetBirdExecutablePath() + if err != nil { + return "", fmt.Errorf("get NetBird executable path: %w", err) + } + + hostLine := strings.Join(deduplicatedPatterns, " ") + config := fmt.Sprintf("Host %s\n", hostLine) + + if runtime.GOOS == "windows" { + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) + } else { + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath) + } + config += " PreferredAuthentications password,publickey,keyboard-interactive\n" + config += " PasswordAuthentication yes\n" + config += " PubkeyAuthentication yes\n" + config += " BatchMode no\n" + config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) + config += " StrictHostKeyChecking no\n" + + if runtime.GOOS == "windows" { + config += " UserKnownHostsFile NUL\n" + } else { + config += " UserKnownHostsFile /dev/null\n" + } + + config += " CheckHostIP no\n" + config += " LogLevel ERROR\n\n" + + return config, nil +} + +func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { + var hostPatterns []string + if peer.IP != "" { + hostPatterns = append(hostPatterns, peer.IP) + } + if peer.FQDN != "" { + hostPatterns = append(hostPatterns, peer.FQDN) + } + if peer.Hostname != "" && peer.Hostname != peer.FQDN { + hostPatterns = append(hostPatterns, peer.Hostname) + } + return hostPatterns +} + +func (m *Manager) writeSSHConfig(sshConfig string) error { + sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + + if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { + return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) + } + + if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil { + return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) + } + + log.Infof("Created NetBird SSH client config: %s", sshConfigPath) + return nil +} + +// RemoveSSHClientConfig removes NetBird SSH configuration +func (m *Manager) RemoveSSHClientConfig() error { + sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + err := os.Remove(sshConfigPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err) + } + if err == nil { + log.Infof("Removed NetBird SSH config: %s", sshConfigPath) + } + return nil +} + +func (m *Manager) getNetBirdExecutablePath() (string, error) { + execPath, err := os.Executable() + if err != nil { + return "", fmt.Errorf("retrieve executable path: %w", err) + } + + realPath, err := filepath.EvalSymlinks(execPath) + if err != nil { + log.Debugf("symlink resolution failed: %v", err) + return execPath, nil + } + + return realPath, nil +} + +// GetSSHConfigDir returns the SSH config directory path +func (m *Manager) GetSSHConfigDir() string { + return m.sshConfigDir +} + +// GetSSHConfigFile returns the SSH config file name +func (m *Manager) GetSSHConfigFile() string { + return m.sshConfigFile +} diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go new file mode 100644 index 000000000..dc3ad95b3 --- /dev/null +++ b/client/ssh/config/manager_test.go @@ -0,0 +1,159 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestManager_SetupSSHClientConfig(t *testing.T) { + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Test SSH config generation with peers + peers := []PeerSSHInfo{ + { + Hostname: "peer1", + IP: "100.125.1.1", + FQDN: "peer1.nb.internal", + }, + { + Hostname: "peer2", + IP: "100.125.1.2", + FQDN: "peer2.nb.internal", + }, + } + + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Read generated config + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + content, err := os.ReadFile(configPath) + require.NoError(t, err) + + configStr := string(content) + + // Verify the basic SSH config structure exists + assert.Contains(t, configStr, "# NetBird SSH client configuration") + assert.Contains(t, configStr, "Generated automatically - do not edit manually") + + // Check that peer hostnames are included + assert.Contains(t, configStr, "100.125.1.1") + assert.Contains(t, configStr, "100.125.1.2") + assert.Contains(t, configStr, "peer1.nb.internal") + assert.Contains(t, configStr, "peer2.nb.internal") + + // Check platform-specific UserKnownHostsFile + if runtime.GOOS == "windows" { + assert.Contains(t, configStr, "UserKnownHostsFile NUL") + } else { + assert.Contains(t, configStr, "UserKnownHostsFile /dev/null") + } +} + +func TestGetSystemSSHConfigDir(t *testing.T) { + configDir := getSystemSSHConfigDir() + + // Path should not be empty + assert.NotEmpty(t, configDir) + + // Should be an absolute path + assert.True(t, filepath.IsAbs(configDir)) + + // On Unix systems, should start with /etc + // On Windows, should contain ProgramData + if runtime.GOOS == "windows" { + assert.Contains(t, strings.ToLower(configDir), "programdata") + } else { + assert.Contains(t, configDir, "/etc/ssh") + } +} + +func TestManager_PeerLimit(t *testing.T) { + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Generate many peers (more than limit) + var peers []PeerSSHInfo + for i := 0; i < MaxPeersForSSHConfig+10; i++ { + peers = append(peers, PeerSSHInfo{ + Hostname: fmt.Sprintf("peer%d", i), + IP: fmt.Sprintf("100.125.1.%d", i%254+1), + FQDN: fmt.Sprintf("peer%d.nb.internal", i), + }) + } + + // Test that SSH config generation is skipped when too many peers + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Config should not be created due to peer limit + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + _, err = os.Stat(configPath) + assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") +} + +func TestManager_ForcedSSHConfig(t *testing.T) { + // Set force environment variable + t.Setenv(EnvForceSSHConfig, "true") + + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Generate many peers (more than limit) + var peers []PeerSSHInfo + for i := 0; i < MaxPeersForSSHConfig+10; i++ { + peers = append(peers, PeerSSHInfo{ + Hostname: fmt.Sprintf("peer%d", i), + IP: fmt.Sprintf("100.125.1.%d", i%254+1), + FQDN: fmt.Sprintf("peer%d.nb.internal", i), + }) + } + + // Test that SSH config generation is forced despite many peers + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Config should be created despite peer limit due to force flag + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + _, err = os.Stat(configPath) + require.NoError(t, err, "SSH config should be created when forced") + + // Verify config contains peer hostnames + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + assert.Contains(t, configStr, "peer0.nb.internal") + assert.Contains(t, configStr, "peer1.nb.internal") +} diff --git a/client/ssh/config/shutdown_state.go b/client/ssh/config/shutdown_state.go new file mode 100644 index 000000000..22f0e0678 --- /dev/null +++ b/client/ssh/config/shutdown_state.go @@ -0,0 +1,22 @@ +package config + +// ShutdownState represents SSH configuration state that needs to be cleaned up. +type ShutdownState struct { + SSHConfigDir string + SSHConfigFile string +} + +// Name returns the state name for the state manager. +func (s *ShutdownState) Name() string { + return "ssh_config_state" +} + +// Cleanup removes SSH client configuration files. +func (s *ShutdownState) Cleanup() error { + manager := &Manager{ + sshConfigDir: s.SSHConfigDir, + sshConfigFile: s.SSHConfigFile, + } + + return manager.RemoveSSHClientConfig() +} diff --git a/client/ssh/detection/detection.go b/client/ssh/detection/detection.go new file mode 100644 index 000000000..487f4665a --- /dev/null +++ b/client/ssh/detection/detection.go @@ -0,0 +1,99 @@ +package detection + +import ( + "bufio" + "context" + "net" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // ServerIdentifier is the base response for NetBird SSH servers + ServerIdentifier = "NetBird-SSH-Server" + // ProxyIdentifier is the base response for NetBird SSH proxy + ProxyIdentifier = "NetBird-SSH-Proxy" + // JWTRequiredMarker is appended to responses when JWT is required + JWTRequiredMarker = "NetBird-JWT-Required" + + // Timeout is the timeout for SSH server detection + Timeout = 5 * time.Second +) + +type ServerType string + +const ( + ServerTypeNetBirdJWT ServerType = "netbird-jwt" + ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt" + ServerTypeRegular ServerType = "regular" +) + +// Dialer provides network connection capabilities +type Dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// RequiresJWT checks if the server type requires JWT authentication +func (s ServerType) RequiresJWT() bool { + return s == ServerTypeNetBirdJWT +} + +// ExitCode returns the exit code for the detect command +func (s ServerType) ExitCode() int { + switch s { + case ServerTypeNetBirdJWT: + return 0 + case ServerTypeNetBirdNoJWT: + return 1 + case ServerTypeRegular: + return 2 + default: + return 2 + } +} + +// DetectSSHServerType detects SSH server type using the provided dialer +func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) { + targetAddr := net.JoinHostPort(host, strconv.Itoa(port)) + + conn, err := dialer.DialContext(ctx, "tcp", targetAddr) + if err != nil { + log.Debugf("SSH connection failed for detection: %v", err) + return ServerTypeRegular, nil + } + defer conn.Close() + + if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil { + log.Debugf("set read deadline: %v", err) + return ServerTypeRegular, nil + } + + reader := bufio.NewReader(conn) + serverBanner, err := reader.ReadString('\n') + if err != nil { + log.Debugf("read SSH banner: %v", err) + return ServerTypeRegular, nil + } + + serverBanner = strings.TrimSpace(serverBanner) + log.Debugf("SSH server banner: %s", serverBanner) + + if !strings.HasPrefix(serverBanner, "SSH-") { + log.Debugf("Invalid SSH banner") + return ServerTypeRegular, nil + } + + if !strings.Contains(serverBanner, ServerIdentifier) { + log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier) + return ServerTypeRegular, nil + } + + if strings.Contains(serverBanner, JWTRequiredMarker) { + return ServerTypeNetBirdJWT, nil + } + + return ServerTypeNetBirdNoJWT, nil +} diff --git a/client/ssh/login.go b/client/ssh/login.go deleted file mode 100644 index cb2615e55..000000000 --- a/client/ssh/login.go +++ /dev/null @@ -1,53 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "net" - "net/netip" - "os" - "os/exec" - "runtime" - - "github.com/netbirdio/netbird/util" -) - -func isRoot() bool { - return os.Geteuid() == 0 -} - -func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) { - if !isRoot() { - shell := getUserShell(user) - if shell == "" { - shell = "/bin/sh" - } - - return shell, []string{"-l"}, nil - } - - loginPath, err = exec.LookPath("login") - if err != nil { - return "", nil, err - } - - addrPort, err := netip.ParseAddrPort(remoteAddr.String()) - if err != nil { - return "", nil, err - } - - switch runtime.GOOS { - case "linux": - if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") { - return loginPath, []string{"-f", user, "-p"}, nil - } - return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil - case "darwin": - return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil - case "freebsd": - return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil - default: - return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) - } -} diff --git a/client/ssh/lookup.go b/client/ssh/lookup.go deleted file mode 100644 index 9a7f6ff2e..000000000 --- a/client/ssh/lookup.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !darwin -// +build !darwin - -package ssh - -import "os/user" - -func userNameLookup(username string) (*user.User, error) { - if username == "" || (username == "root" && !isRoot()) { - return user.Current() - } - - return user.Lookup(username) -} diff --git a/client/ssh/lookup_darwin.go b/client/ssh/lookup_darwin.go deleted file mode 100644 index 913d049dc..000000000 --- a/client/ssh/lookup_darwin.go +++ /dev/null @@ -1,51 +0,0 @@ -//go:build darwin -// +build darwin - -package ssh - -import ( - "bytes" - "fmt" - "os/exec" - "os/user" - "strings" -) - -func userNameLookup(username string) (*user.User, error) { - if username == "" || (username == "root" && !isRoot()) { - return user.Current() - } - - var userObject *user.User - userObject, err := user.Lookup(username) - if err != nil && err.Error() == user.UnknownUserError(username).Error() { - return idUserNameLookup(username) - } else if err != nil { - return nil, err - } - - return userObject, nil -} - -func idUserNameLookup(username string) (*user.User, error) { - cmd := exec.Command("id", "-P", username) - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err) - } - colon := ":" - - if !bytes.Contains(out, []byte(username+colon)) { - return nil, fmt.Errorf("unable to find user in returned string") - } - // netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh - parts := strings.SplitN(string(out), colon, 10) - userObject := &user.User{ - Username: parts[0], - Uid: parts[2], - Gid: parts[3], - Name: parts[7], - HomeDir: parts[8], - } - return userObject, nil -} diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go new file mode 100644 index 000000000..bc8a84b89 --- /dev/null +++ b/client/ssh/proxy/proxy.go @@ -0,0 +1,392 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/version" +) + +const ( + // sshConnectionTimeout is the timeout for SSH TCP connection establishment + sshConnectionTimeout = 120 * time.Second + // sshHandshakeTimeout is the timeout for SSH handshake completion + sshHandshakeTimeout = 30 * time.Second + + jwtAuthErrorMsg = "JWT authentication: %w" +) + +type SSHProxy struct { + daemonAddr string + targetHost string + targetPort int + stderr io.Writer + conn *grpc.ClientConn + daemonClient proto.DaemonServiceClient +} + +func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) { + grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://") + grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, fmt.Errorf("connect to daemon: %w", err) + } + + return &SSHProxy{ + daemonAddr: daemonAddr, + targetHost: targetHost, + targetPort: targetPort, + stderr: stderr, + conn: grpcConn, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + }, nil +} + +func (p *SSHProxy) Close() error { + if p.conn != nil { + return p.conn.Close() + } + return nil +} + +func (p *SSHProxy) Connect(ctx context.Context) error { + hint := profilemanager.GetLoginHint() + + jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint) + if err != nil { + return fmt.Errorf(jwtAuthErrorMsg, err) + } + + return p.runProxySSHServer(ctx, jwtToken) +} + +func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error { + serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion()) + + sshServer := &ssh.Server{ + Handler: func(s ssh.Session) { + p.handleSSHSession(ctx, s, jwtToken) + }, + ChannelHandlers: map[string]ssh.ChannelHandler{ + "session": ssh.DefaultSessionHandler, + "direct-tcpip": p.directTCPIPHandler, + }, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": func(s ssh.Session) { + p.sftpSubsystemHandler(s, jwtToken) + }, + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": p.tcpipForwardHandler, + "cancel-tcpip-forward": p.cancelTcpipForwardHandler, + }, + Version: serverVersion, + } + + hostKey, err := generateHostKey() + if err != nil { + return fmt.Errorf("generate host key: %w", err) + } + sshServer.HostSigners = []ssh.Signer{hostKey} + + conn := &stdioConn{ + stdin: os.Stdin, + stdout: os.Stdout, + } + + sshServer.HandleConn(conn) + + return nil +} + +func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) { + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + + sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err) + return + } + defer func() { _ = sshClient.Close() }() + + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) + return + } + defer func() { _ = serverSession.Close() }() + + serverSession.Stdin = session + serverSession.Stdout = session + serverSession.Stderr = session.Stderr() + + ptyReq, winCh, isPty := session.Pty() + if isPty { + if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil { + log.Debugf("PTY request to backend: %v", err) + } + + go func() { + for win := range winCh { + if err := serverSession.WindowChange(win.Height, win.Width); err != nil { + log.Debugf("window change: %v", err) + } + } + }() + } + + if len(session.Command()) > 0 { + if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { + log.Debugf("run command: %v", err) + p.handleProxyExitCode(session, err) + } + return + } + + if err = serverSession.Shell(); err != nil { + log.Debugf("start shell: %v", err) + return + } + if err := serverSession.Wait(); err != nil { + log.Debugf("session wait: %v", err) + p.handleProxyExitCode(session, err) + } +} + +func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { + var exitErr *cryptossh.ExitError + if errors.As(err, &exitErr) { + if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil { + log.Debugf("set exit status: %v", exitErr) + } + } +} + +func generateHostKey() (ssh.Signer, error) { + keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + if err != nil { + return nil, fmt.Errorf("generate ED25519 key: %w", err) + } + + signer, err := cryptossh.ParsePrivateKey(keyPEM) + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + return signer, nil +} + +type stdioConn struct { + stdin io.Reader + stdout io.Writer + closed bool + mu sync.Mutex +} + +func (c *stdioConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, io.EOF + } + c.mu.Unlock() + return c.stdin.Read(b) +} + +func (c *stdioConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, io.ErrClosedPipe + } + c.mu.Unlock() + return c.stdout.Write(b) +} + +func (c *stdioConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *stdioConn) LocalAddr() net.Addr { + return &net.UnixAddr{Name: "stdio", Net: "unix"} +} + +func (c *stdioConn) RemoteAddr() net.Addr { + return &net.UnixAddr{Name: "stdio", Net: "unix"} +} + +func (c *stdioConn) SetDeadline(_ time.Time) error { + return nil +} + +func (c *stdioConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (c *stdioConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) { + _ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy") +} + +func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { + ctx, cancel := context.WithCancel(s.Context()) + defer cancel() + + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + + sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken) + if err != nil { + _, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err) + _ = s.Exit(1) + return + } + defer func() { + if err := sshClient.Close(); err != nil { + log.Debugf("close SSH client: %v", err) + } + }() + + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(s, "create server session: %v\n", err) + _ = s.Exit(1) + return + } + defer func() { + if err := serverSession.Close(); err != nil { + log.Debugf("close server session: %v", err) + } + }() + + stdin, stdout, err := p.setupSFTPPipes(serverSession) + if err != nil { + log.Debugf("setup SFTP pipes: %v", err) + _ = s.Exit(1) + return + } + + if err := serverSession.RequestSubsystem("sftp"); err != nil { + _, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err) + _ = s.Exit(1) + return + } + + p.runSFTPBridge(ctx, s, stdin, stdout, serverSession) +} + +func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) { + stdin, err := serverSession.StdinPipe() + if err != nil { + return nil, nil, fmt.Errorf("get stdin pipe: %w", err) + } + + stdout, err := serverSession.StdoutPipe() + if err != nil { + return nil, nil, fmt.Errorf("get stdout pipe: %w", err) + } + + return stdin, stdout, nil +} + +func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) { + copyErrCh := make(chan error, 2) + + go func() { + _, err := io.Copy(stdin, s) + if err != nil { + log.Debugf("SFTP client to server copy: %v", err) + } + if err := stdin.Close(); err != nil { + log.Debugf("close stdin: %v", err) + } + copyErrCh <- err + }() + + go func() { + _, err := io.Copy(s, stdout) + if err != nil { + log.Debugf("SFTP server to client copy: %v", err) + } + copyErrCh <- err + }() + + go func() { + <-ctx.Done() + if err := serverSession.Close(); err != nil { + log.Debugf("force close server session on context cancellation: %v", err) + } + }() + + for i := 0; i < 2; i++ { + if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) { + log.Debugf("SFTP copy error: %v", err) + } + } + + if err := serverSession.Wait(); err != nil { + log.Debugf("SFTP session ended: %v", err) + } +} + +func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { + return false, []byte("port forwarding not supported in proxy") +} + +func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { + return true, nil +} + +func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { + config := &cryptossh.ClientConfig{ + User: user, + Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)}, + Timeout: sshHandshakeTimeout, + HostKeyCallback: p.verifyHostKey, + } + + dialer := &net.Dialer{ + Timeout: sshConnectionTimeout, + } + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("connect to server: %w", err) + } + + clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config) + if err != nil { + _ = conn.Close() + return nil, fmt.Errorf("SSH handshake: %w", err) + } + + return cryptossh.NewClient(clientConn, chans, reqs), nil +} + +func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error { + verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient) + callback := nbssh.CreateHostKeyCallback(verifier) + return callback(hostname, remote, key) +} diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go new file mode 100644 index 000000000..c5036da37 --- /dev/null +++ b/client/ssh/proxy/proxy_test.go @@ -0,0 +1,367 @@ +package proxy + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "runtime" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/server" + "github.com/netbirdio/netbird/client/ssh/testutil" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" +) + +func TestMain(m *testing.M) { + if len(os.Args) > 2 && os.Args[1] == "ssh" { + if os.Args[2] == "exec" { + if len(os.Args) > 3 { + cmd := os.Args[3] + if cmd == "echo" && len(os.Args) > 4 { + fmt.Fprintln(os.Stdout, os.Args[4]) + os.Exit(0) + } + } + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args) + os.Exit(1) + } + } + + code := m.Run() + + testutil.CleanupTestUsers() + + os.Exit(code) +} + +func TestSSHProxy_verifyHostKey(t *testing.T) { + t.Run("calls daemon to verify host key", func(t *testing.T) { + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer func() { _ = grpcConn.Close() }() + + proxy := &SSHProxy{ + daemonAddr: mockDaemon.addr, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + } + + testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + testPubKey, err := nbssh.GeneratePublicKey(testKey) + require.NoError(t, err) + + mockDaemon.setHostKey("test-host", testPubKey) + + err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey)) + assert.NoError(t, err) + }) + + t.Run("rejects unknown host key", func(t *testing.T) { + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer func() { _ = grpcConn.Close() }() + + proxy := &SSHProxy{ + daemonAddr: mockDaemon.addr, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + } + + unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey) + require.NoError(t, err) + + err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "peer unknown-host not found in network") + }) +} + +func TestSSHProxy_Connect(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // TODO: Windows test times out - user switching and command execution tested on Linux + if runtime.GOOS == "windows" { + t.Skip("Skipping on Windows - covered by Linux tests") + } + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + hostPubKey, err := nbssh.GeneratePublicKey(hostKey) + require.NoError(t, err) + + serverConfig := &server.Config{ + HostKeyPEM: hostKey, + JWT: &server.JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + }, + } + sshServer := server.New(serverConfig) + sshServer.SetAllowRootLogin(true) + + sshServerAddr := server.StartTestServer(t, sshServer) + defer func() { _ = sshServer.Stop() }() + + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + host, portStr, err := net.SplitHostPort(sshServerAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + mockDaemon.setHostKey(host, hostPubKey) + + validToken := generateValidJWT(t, privateKey, issuer, audience) + mockDaemon.setJWTToken(validToken) + + proxyInstance, err := New(mockDaemon.addr, host, port, nil) + require.NoError(t, err) + + clientConn, proxyConn := net.Pipe() + defer func() { _ = clientConn.Close() }() + + origStdin := os.Stdin + origStdout := os.Stdout + defer func() { + os.Stdin = origStdin + os.Stdout = origStdout + }() + + stdinReader, stdinWriter, err := os.Pipe() + require.NoError(t, err) + stdoutReader, stdoutWriter, err := os.Pipe() + require.NoError(t, err) + + os.Stdin = stdinReader + os.Stdout = stdoutWriter + + go func() { + _, _ = io.Copy(stdinWriter, proxyConn) + }() + go func() { + _, _ = io.Copy(proxyConn, stdoutReader) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + connectErrCh := make(chan error, 1) + go func() { + connectErrCh <- proxyInstance.Connect(ctx) + }() + + sshConfig := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 3 * time.Second, + } + + sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig) + require.NoError(t, err, "Should connect to proxy server") + defer func() { _ = sshClientConn.Close() }() + + sshClient := cryptossh.NewClient(sshClientConn, chans, reqs) + + session, err := sshClient.NewSession() + require.NoError(t, err, "Should create session through full proxy to backend") + + outputCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + output, err := session.Output("echo hello-from-proxy") + outputCh <- output + errCh <- err + }() + + select { + case output := <-outputCh: + err := <-errCh + require.NoError(t, err, "Command should execute successfully through proxy") + assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy") + case <-time.After(3 * time.Second): + t.Fatal("Command execution timed out") + } + + _ = session.Close() + _ = sshClient.Close() + _ = clientConn.Close() + cancel() +} + +type mockDaemonServer struct { + proto.UnimplementedDaemonServiceServer + hostKeys map[string][]byte + jwtToken string +} + +func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) { + key, found := m.hostKeys[req.PeerAddress] + return &proto.GetPeerSSHHostKeyResponse{ + Found: found, + SshHostKey: key, + }, nil +} + +func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) { + return &proto.RequestJWTAuthResponse{ + CachedToken: m.jwtToken, + }, nil +} + +func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) { + return &proto.WaitJWTTokenResponse{ + Token: m.jwtToken, + }, nil +} + +type mockDaemon struct { + addr string + server *grpc.Server + impl *mockDaemonServer +} + +func startMockDaemon(t *testing.T) *mockDaemon { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + impl := &mockDaemonServer{ + hostKeys: make(map[string][]byte), + jwtToken: "test-jwt-token", + } + + grpcServer := grpc.NewServer() + proto.RegisterDaemonServiceServer(grpcServer, impl) + + go func() { + _ = grpcServer.Serve(listener) + }() + + return &mockDaemon{ + addr: listener.Addr().String(), + server: grpcServer, + impl: impl, + } +} + +func (m *mockDaemon) setHostKey(addr string, pubKey []byte) { + m.impl.hostKeys[addr] = pubKey +} + +func (m *mockDaemon) setJWTToken(token string) { + m.impl.jwtToken = token +} + +func (m *mockDaemon) stop() { + if m.server != nil { + m.server.Stop() + } +} + +func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey { + t.Helper() + pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes) + require.NoError(t, err) + return pubKey +} + +func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) { + t.Helper() + privateKey, jwksJSON := generateTestJWKS(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(jwksJSON); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + + return server, privateKey, server.URL +} + +func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { + t.Helper() + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey := &privateKey.PublicKey + n := publicKey.N.Bytes() + e := publicKey.E + + jwk := nbjwt.JSONWebKey{ + Kty: "RSA", + Kid: "test-key-id", + Use: "sig", + N: base64.RawURLEncoding.EncodeToString(n), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()), + } + + jwks := nbjwt.Jwks{ + Keys: []nbjwt.JSONWebKey{jwk}, + } + + jwksJSON, err := json.Marshal(jwks) + require.NoError(t, err) + + return privateKey, jwksJSON +} + +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { + t.Helper() + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + return tokenString +} diff --git a/client/ssh/server.go b/client/ssh/server.go deleted file mode 100644 index 8c5db2547..000000000 --- a/client/ssh/server.go +++ /dev/null @@ -1,280 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "io" - "net" - "os" - "os/exec" - "os/user" - "runtime" - "strings" - "sync" - "time" - - "github.com/creack/pty" - "github.com/gliderlabs/ssh" - log "github.com/sirupsen/logrus" -) - -// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server -const DefaultSSHPort = 44338 - -// TerminalTimeout is the timeout for terminal session to be ready -const TerminalTimeout = 10 * time.Second - -// TerminalBackoffDelay is the delay between terminal session readiness checks -const TerminalBackoffDelay = 500 * time.Millisecond - -// DefaultSSHServer is a function that creates DefaultServer -func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { - return newDefaultServer(hostKeyPEM, addr) -} - -// Server is an interface of SSH server -type Server interface { - // Stop stops SSH server. - Stop() error - // Start starts SSH server. Blocking - Start() error - // RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys - RemoveAuthorizedKey(peer string) - // AddAuthorizedKey add a given peer key to server authorized keys - AddAuthorizedKey(peer, newKey string) error -} - -// DefaultServer is the embedded NetBird SSH server -type DefaultServer struct { - listener net.Listener - // authorizedKeys is ssh pub key indexed by peer WireGuard public key - authorizedKeys map[string]ssh.PublicKey - mu sync.Mutex - hostKeyPEM []byte - sessions []ssh.Session -} - -// newDefaultServer creates new server with provided host key -func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - allowedKeys := make(map[string]ssh.PublicKey) - return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil -} - -// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys -func (srv *DefaultServer) RemoveAuthorizedKey(peer string) { - srv.mu.Lock() - defer srv.mu.Unlock() - - delete(srv.authorizedKeys, peer) -} - -// AddAuthorizedKey add a given peer key to server authorized keys -func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error { - srv.mu.Lock() - defer srv.mu.Unlock() - - parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey)) - if err != nil { - return err - } - - srv.authorizedKeys[peer] = parsedKey - return nil -} - -// Stop stops SSH server. -func (srv *DefaultServer) Stop() error { - srv.mu.Lock() - defer srv.mu.Unlock() - err := srv.listener.Close() - if err != nil { - return err - } - for _, session := range srv.sessions { - err := session.Close() - if err != nil { - log.Warnf("failed closing SSH session from %v", err) - } - } - - return nil -} - -func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { - srv.mu.Lock() - defer srv.mu.Unlock() - - for _, allowed := range srv.authorizedKeys { - if ssh.KeysEqual(allowed, key) { - return true - } - } - - return false -} - -func prepareUserEnv(user *user.User, shell string) []string { - return []string{ - fmt.Sprint("SHELL=" + shell), - fmt.Sprint("USER=" + user.Username), - fmt.Sprint("HOME=" + user.HomeDir), - } -} - -func acceptEnv(s string) bool { - split := strings.Split(s, "=") - if len(split) != 2 { - return false - } - return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_") -} - -// sessionHandler handles SSH session post auth -func (srv *DefaultServer) sessionHandler(session ssh.Session) { - srv.mu.Lock() - srv.sessions = append(srv.sessions, session) - srv.mu.Unlock() - - defer func() { - err := session.Close() - if err != nil { - return - } - }() - - log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String()) - - localUser, err := userNameLookup(session.User()) - if err != nil { - _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint - err = session.Exit(1) - if err != nil { - return - } - log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User()) - return - } - - ptyReq, winCh, isPty := session.Pty() - if isPty { - loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr()) - if err != nil { - log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String()) - return - } - cmd := exec.Command(loginCmd, loginArgs...) - go func() { - <-session.Context().Done() - if cmd.Process == nil { - return - } - err := cmd.Process.Kill() - if err != nil { - log.Debugf("failed killing SSH process %v", err) - return - } - }() - cmd.Dir = localUser.HomeDir - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) - cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...) - for _, v := range session.Environ() { - if acceptEnv(v) { - cmd.Env = append(cmd.Env, v) - } - } - - log.Debugf("Login command: %s", cmd.String()) - file, err := pty.Start(cmd) - if err != nil { - log.Errorf("failed starting SSH server: %v", err) - } - - go func() { - for win := range winCh { - setWinSize(file, win.Width, win.Height) - } - }() - - srv.stdInOut(file, session) - - err = cmd.Wait() - if err != nil { - return - } - } else { - _, err := io.WriteString(session, "only PTY is supported.\n") - if err != nil { - return - } - err = session.Exit(1) - if err != nil { - return - } - } - log.Debugf("SSH session ended") -} - -func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { - go func() { - // stdin - _, err := io.Copy(file, session) - if err != nil { - _ = session.Exit(1) - return - } - }() - - // AWS Linux 2 machines need some time to open the terminal so we need to wait for it - timer := time.NewTimer(TerminalTimeout) - for { - select { - case <-timer.C: - _, _ = session.Write([]byte("Reached timeout while opening connection\n")) - _ = session.Exit(1) - return - default: - // stdout - writtenBytes, err := io.Copy(session, file) - if err != nil && writtenBytes != 0 { - _ = session.Exit(0) - return - } - time.Sleep(TerminalBackoffDelay) - } - } -} - -// Start starts SSH server. Blocking -func (srv *DefaultServer) Start() error { - log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String()) - - publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler) - hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM) - err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM) - if err != nil { - return err - } - - return nil -} - -func getUserShell(userID string) string { - if runtime.GOOS == "linux" { - output, _ := exec.Command("getent", "passwd", userID).Output() - line := strings.SplitN(string(output), ":", 10) - if len(line) > 6 { - return strings.TrimSpace(line[6]) - } - } - - shell := os.Getenv("SHELL") - if shell == "" { - shell = "/bin/sh" - } - return shell -} diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go new file mode 100644 index 000000000..7a01ce4f6 --- /dev/null +++ b/client/ssh/server/command_execution.go @@ -0,0 +1,206 @@ +package server + +import ( + "errors" + "fmt" + "io" + "os" + "os/exec" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// handleCommand executes an SSH command with privilege validation +func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { + hasPty := winCh != nil + + commandType := "command" + if hasPty { + commandType = "Pty command" + } + + logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) + + execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) + if err != nil { + logger.Errorf("%s creation failed: %v", commandType, err) + + errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType) + if hasPty { + errorMsg += " with Pty" + } + errorMsg += "\n" + + if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return + } + + if !hasPty { + if s.executeCommand(logger, session, execCmd, cleanup) { + logger.Debugf("%s execution completed", commandType) + } + return + } + + defer cleanup() + + ptyReq, _, _ := session.Pty() + if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) { + logger.Debugf("%s execution completed", commandType) + } +} + +func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { + localUser := privilegeResult.User + if localUser == nil { + return nil, nil, errors.New("no user in privilege result") + } + + // If PTY requested but su doesn't support --pty, skip su and use executor + // This ensures PTY functionality is provided (executor runs within our allocated PTY) + if hasPty && !s.suSupportsPty { + log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") + cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + if err != nil { + return nil, nil, fmt.Errorf("create command with privileges: %w", err) + } + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, cleanup, nil + } + + // Try su first for system integration (PAM/audit) when privileged + cmd, err := s.createSuCommand(session, localUser, hasPty) + if err != nil || privilegeResult.UsedFallback { + log.Debugf("su command failed, falling back to executor: %v", err) + cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + if err != nil { + return nil, nil, fmt.Errorf("create command with privileges: %w", err) + } + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, cleanup, nil + } + + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, func() {}, nil +} + +// executeCommand executes the command and handles I/O and exit codes +func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool { + defer cleanup() + + s.setupProcessGroup(execCmd) + + stdinPipe, err := execCmd.StdinPipe() + if err != nil { + logger.Errorf("create stdin pipe: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + execCmd.Stdout = session + execCmd.Stderr = session.Stderr() + + if execCmd.Dir != "" { + if _, err := os.Stat(execCmd.Dir); err != nil { + logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err) + execCmd.Dir = "/" + } + } + + if err := execCmd.Start(); err != nil { + logger.Errorf("command start failed: %v", err) + // no user message for exec failure, just exit + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + go s.handleCommandIO(logger, stdinPipe, session) + return s.waitForCommandCleanup(logger, session, execCmd) +} + +// handleCommandIO manages stdin/stdout copying in a goroutine +func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) { + defer func() { + if err := stdinPipe.Close(); err != nil { + logger.Debugf("stdin pipe close error: %v", err) + } + }() + if _, err := io.Copy(stdinPipe, session); err != nil { + logger.Debugf("stdin copy error: %v", err) + } +} + +// waitForCommandCleanup waits for command completion with session disconnect handling +func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool { + ctx := session.Context() + done := make(chan error, 1) + go func() { + done <- execCmd.Wait() + }() + + select { + case <-ctx.Done(): + logger.Debugf("session cancelled, terminating command") + s.killProcessGroup(execCmd) + + select { + case err := <-done: + logger.Tracef("command terminated after session cancellation: %v", err) + case <-time.After(5 * time.Second): + logger.Warnf("command did not terminate within 5 seconds after session cancellation") + } + + if err := session.Exit(130); err != nil { + logSessionExitError(logger, err) + } + return false + + case err := <-done: + return s.handleCommandCompletion(logger, session, err) + } +} + +// handleCommandCompletion handles command completion +func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool { + if err != nil { + logger.Debugf("command execution failed: %v", err) + s.handleSessionExit(session, err, logger) + return false + } + + s.handleSessionExit(session, nil, logger) + return true +} + +// handleSessionExit handles command errors and sets appropriate exit codes +func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) { + if err == nil { + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } + return + } + + var exitError *exec.ExitError + if errors.As(err, &exitError) { + if err := session.Exit(exitError.ExitCode()); err != nil { + logSessionExitError(logger, err) + } + } else { + logger.Debugf("non-exit error in command execution: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + } +} diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go new file mode 100644 index 000000000..6473f8273 --- /dev/null +++ b/client/ssh/server/command_execution_js.go @@ -0,0 +1,52 @@ +//go:build js + +package server + +import ( + "context" + "errors" + "os/exec" + "os/user" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") + +// createSuCommand is not supported on JS/WASM +func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { + return nil, errNotSupported +} + +// createExecutorCommand is not supported on JS/WASM +func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { + return nil, nil, errNotSupported +} + +// prepareCommandEnv is not supported on JS/WASM +func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string { + return nil +} + +// setupProcessGroup is not supported on JS/WASM +func (s *Server) setupProcessGroup(_ *exec.Cmd) { +} + +// killProcessGroup is not supported on JS/WASM +func (s *Server) killProcessGroup(*exec.Cmd) { +} + +// detectSuPtySupport always returns false on JS/WASM +func (s *Server) detectSuPtySupport(context.Context) bool { + return false +} + +// executeCommandWithPty is not supported on JS/WASM +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + logger.Errorf("PTY command execution not supported on JS/WASM") + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false +} diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go new file mode 100644 index 000000000..da059fed9 --- /dev/null +++ b/client/ssh/server/command_execution_unix.go @@ -0,0 +1,329 @@ +//go:build unix + +package server + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "os/exec" + "os/user" + "strings" + "sync" + "syscall" + "time" + + "github.com/creack/pty" + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// ptyManager manages Pty file operations with thread safety +type ptyManager struct { + file *os.File + mu sync.RWMutex + closed bool + closeErr error + once sync.Once +} + +func newPtyManager(file *os.File) *ptyManager { + return &ptyManager{file: file} +} + +func (pm *ptyManager) Close() error { + pm.once.Do(func() { + pm.mu.Lock() + pm.closed = true + pm.closeErr = pm.file.Close() + pm.mu.Unlock() + }) + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.closeErr +} + +func (pm *ptyManager) Setsize(ws *pty.Winsize) error { + pm.mu.RLock() + defer pm.mu.RUnlock() + if pm.closed { + return errors.New("pty is closed") + } + return pty.Setsize(pm.file, ws) +} + +func (pm *ptyManager) File() *os.File { + return pm.file +} + +// detectSuPtySupport checks if su supports the --pty flag +func (s *Server) detectSuPtySupport(ctx context.Context) bool { + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + cmd := exec.CommandContext(ctx, "su", "--help") + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("su --help failed (may not support --help): %v", err) + return false + } + + supported := strings.Contains(string(output), "--pty") + log.Debugf("su --pty support detected: %v", supported) + return supported +} + +// createSuCommand creates a command using su -l -c for privilege switching +func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { + suPath, err := exec.LookPath("su") + if err != nil { + return nil, fmt.Errorf("su command not available: %w", err) + } + + command := session.RawCommand() + if command == "" { + return nil, fmt.Errorf("no command specified for su execution") + } + + args := []string{"-l"} + if hasPty && s.suSupportsPty { + args = append(args, "--pty") + } + args = append(args, localUser.Username, "-c", command) + + cmd := exec.CommandContext(session.Context(), suPath, args...) + cmd.Dir = localUser.HomeDir + + return cmd, nil +} + +// getShellCommandArgs returns the shell command and arguments for executing a command string +func (s *Server) getShellCommandArgs(shell, cmdString string) []string { + if cmdString == "" { + return []string{shell, "-l"} + } + return []string{shell, "-l", "-c", cmdString} +} + +// prepareCommandEnv prepares environment variables for command execution on Unix +func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} + +// executeCommandWithPty executes a command with PTY allocation +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + termType := ptyReq.Term + if termType == "" { + termType = "xterm-256color" + } + execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType)) + + return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) +} + +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) + if err != nil { + logger.Errorf("Pty command creation failed: %v", err) + errorMsg := "User switching failed - login command not available\r\n" + if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + logger.Infof("starting interactive shell: %s", execCmd.Path) + return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) +} + +// runPtyCommand runs a command with PTY management (common code for interactive and command execution) +func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq) + if err != nil { + logger.Errorf("Pty start failed: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + ptyMgr := newPtyManager(ptmx) + defer func() { + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close error: %v", err) + } + }() + + go s.handlePtyWindowResize(logger, session, ptyMgr, winCh) + s.handlePtyIO(logger, session, ptyMgr) + s.waitForPtyCompletion(logger, session, execCmd, ptyMgr) + return true +} + +func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) { + winSize := &pty.Winsize{ + Cols: uint16(ptyReq.Window.Width), + Rows: uint16(ptyReq.Window.Height), + } + if winSize.Cols == 0 { + winSize.Cols = 80 + } + if winSize.Rows == 0 { + winSize.Rows = 24 + } + + ptmx, err := pty.StartWithSize(execCmd, winSize) + if err != nil { + return nil, fmt.Errorf("start Pty: %w", err) + } + + return ptmx, nil +} + +func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) { + for { + select { + case <-session.Context().Done(): + return + case win, ok := <-winCh: + if !ok { + return + } + if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil { + logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err) + } + } + } +} + +func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) { + ptmx := ptyMgr.File() + + go func() { + if _, err := io.Copy(ptmx, session); err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty input copy error: %v", err) + } + } + }() + + go func() { + defer func() { + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { + logger.Debugf("session close error: %v", err) + } + }() + if _, err := io.Copy(session, ptmx); err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty output copy error: %v", err) + } + } + }() +} + +func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) { + ctx := session.Context() + done := make(chan error, 1) + go func() { + done <- execCmd.Wait() + }() + + select { + case <-ctx.Done(): + s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) + case err := <-done: + s.handlePtyCommandCompletion(logger, session, err) + } +} + +func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) { + logger.Debugf("Pty session cancelled, terminating command") + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close during session cancellation: %v", err) + } + + s.killProcessGroup(execCmd) + + select { + case err := <-done: + if err != nil { + logger.Debugf("Pty command terminated after session cancellation with error: %v", err) + } else { + logger.Debugf("Pty command terminated after session cancellation") + } + case <-time.After(5 * time.Second): + logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation") + } + + if err := session.Exit(130); err != nil { + logSessionExitError(logger, err) + } +} + +func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { + if err != nil { + logger.Debugf("Pty command execution failed: %v", err) + s.handleSessionExit(session, err, logger) + return + } + + // Normal completion + logger.Debugf("Pty command completed successfully") + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } +} + +func (s *Server) setupProcessGroup(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } +} + +func (s *Server) killProcessGroup(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + + logger := log.WithField("pid", cmd.Process.Pid) + pgid := cmd.Process.Pid + + if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil { + logger.Debugf("kill process group SIGTERM: %v", err) + return + } + + const gracePeriod = 500 * time.Millisecond + const checkInterval = 50 * time.Millisecond + + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + timeout := time.After(gracePeriod) + + for { + select { + case <-timeout: + if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil { + logger.Debugf("kill process group SIGKILL: %v", err) + } + return + case <-ticker.C: + if err := syscall.Kill(-pgid, 0); err != nil { + return + } + } + } +} diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go new file mode 100644 index 000000000..37b3ae0ee --- /dev/null +++ b/client/ssh/server/command_execution_windows.go @@ -0,0 +1,430 @@ +package server + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "strings" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/ssh/server/winpty" +) + +// getUserEnvironment retrieves the Windows environment for the target user. +// Follows OpenSSH's resilient approach with graceful degradation on failures. +func (s *Server) getUserEnvironment(username, domain string) ([]string, error) { + userToken, err := s.getUserToken(username, domain) + if err != nil { + return nil, fmt.Errorf("get user token: %w", err) + } + defer func() { + if err := windows.CloseHandle(userToken); err != nil { + log.Debugf("close user token: %v", err) + } + }() + + return s.getUserEnvironmentWithToken(userToken, username, domain) +} + +// getUserEnvironmentWithToken retrieves the Windows environment using an existing token. +func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) { + userProfile, err := s.loadUserProfile(userToken, username, domain) + if err != nil { + log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) + userProfile = fmt.Sprintf("C:\\Users\\%s", username) + } + + envMap := make(map[string]string) + + if err := s.loadSystemEnvironment(envMap); err != nil { + log.Debugf("failed to load system environment from registry: %v", err) + } + + s.setUserEnvironmentVariables(envMap, userProfile, username, domain) + + var env []string + for key, value := range envMap { + env = append(env, key+"="+value) + } + + return env, nil +} + +// getUserToken creates a user token for the specified user. +func (s *Server) getUserToken(username, domain string) (windows.Handle, error) { + privilegeDropper := NewPrivilegeDropper() + token, err := privilegeDropper.createToken(username, domain) + if err != nil { + return 0, fmt.Errorf("generate S4U user token: %w", err) + } + return token, nil +} + +// loadUserProfile loads the Windows user profile and returns the profile path. +func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) { + usernamePtr, err := windows.UTF16PtrFromString(username) + if err != nil { + return "", fmt.Errorf("convert username to UTF-16: %w", err) + } + + var domainUTF16 *uint16 + if domain != "" && domain != "." { + domainUTF16, err = windows.UTF16PtrFromString(domain) + if err != nil { + return "", fmt.Errorf("convert domain to UTF-16: %w", err) + } + } + + type profileInfo struct { + dwSize uint32 + dwFlags uint32 + lpUserName *uint16 + lpProfilePath *uint16 + lpDefaultPath *uint16 + lpServerName *uint16 + lpPolicyPath *uint16 + hProfile windows.Handle + } + + const PI_NOUI = 0x00000001 + + profile := profileInfo{ + dwSize: uint32(unsafe.Sizeof(profileInfo{})), + dwFlags: PI_NOUI, + lpUserName: usernamePtr, + lpServerName: domainUTF16, + } + + userenv := windows.NewLazySystemDLL("userenv.dll") + loadUserProfileW := userenv.NewProc("LoadUserProfileW") + + ret, _, err := loadUserProfileW.Call( + uintptr(userToken), + uintptr(unsafe.Pointer(&profile)), + ) + + if ret == 0 { + return "", fmt.Errorf("LoadUserProfileW: %w", err) + } + + if profile.lpProfilePath == nil { + return "", fmt.Errorf("LoadUserProfileW returned null profile path") + } + + profilePath := windows.UTF16PtrToString(profile.lpProfilePath) + return profilePath, nil +} + +// loadSystemEnvironment loads system-wide environment variables from registry. +func (s *Server) loadSystemEnvironment(envMap map[string]string) error { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, + `SYSTEM\CurrentControlSet\Control\Session Manager\Environment`, + registry.QUERY_VALUE) + if err != nil { + return fmt.Errorf("open system environment registry key: %w", err) + } + defer func() { + if err := key.Close(); err != nil { + log.Debugf("close registry key: %v", err) + } + }() + + return s.readRegistryEnvironment(key, envMap) +} + +// readRegistryEnvironment reads environment variables from a registry key. +func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error { + names, err := key.ReadValueNames(0) + if err != nil { + return fmt.Errorf("read registry value names: %w", err) + } + + for _, name := range names { + value, valueType, err := key.GetStringValue(name) + if err != nil { + log.Debugf("failed to read registry value %s: %v", name, err) + continue + } + + finalValue := s.expandRegistryValue(value, valueType, name) + s.setEnvironmentVariable(envMap, name, finalValue) + } + + return nil +} + +// expandRegistryValue expands registry values if they contain environment variables. +func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string { + if valueType != registry.EXPAND_SZ { + return value + } + + sourcePtr := windows.StringToUTF16Ptr(value) + expandedBuffer := make([]uint16, 1024) + expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer))) + if err != nil { + log.Debugf("failed to expand environment string for %s: %v", name, err) + return value + } + + // If buffer was too small, retry with larger buffer + if expandedLen > uint32(len(expandedBuffer)) { + expandedBuffer = make([]uint16, expandedLen) + expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer))) + if err != nil { + log.Debugf("failed to expand environment string for %s on retry: %v", name, err) + return value + } + } + + if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) { + return windows.UTF16ToString(expandedBuffer[:expandedLen-1]) + } + return value +} + +// setEnvironmentVariable sets an environment variable with special handling for PATH. +func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) { + upperName := strings.ToUpper(name) + + if upperName == "PATH" { + if existing, exists := envMap["PATH"]; exists && existing != value { + envMap["PATH"] = existing + ";" + value + } else { + envMap["PATH"] = value + } + } else { + envMap[upperName] = value + } +} + +// setUserEnvironmentVariables sets critical user-specific environment variables. +func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) { + envMap["USERPROFILE"] = userProfile + + if len(userProfile) >= 2 && userProfile[1] == ':' { + envMap["HOMEDRIVE"] = userProfile[:2] + envMap["HOMEPATH"] = userProfile[2:] + } + + envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming") + envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local") + + tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp") + envMap["TEMP"] = tempDir + envMap["TMP"] = tempDir + + envMap["USERNAME"] = username + if domain != "" && domain != "." { + envMap["USERDOMAIN"] = domain + envMap["USERDNSDOMAIN"] = domain + } + + systemVars := []string{ + "PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION", + "SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT", + "PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC", + } + + for _, sysVar := range systemVars { + if sysValue := os.Getenv(sysVar); sysValue != "" { + envMap[sysVar] = sysValue + } + } +} + +// prepareCommandEnv prepares environment variables for command execution on Windows +func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { + username, domain := s.parseUsername(localUser.Username) + userEnv, err := s.getUserEnvironment(username, domain) + if err != nil { + log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err) + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env + } + + env := userEnv + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} + +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + if privilegeResult.User == nil { + logger.Errorf("no user in privilege result") + return false + } + + cmd := session.Command() + shell := getUserShell(privilegeResult.User.Uid) + + if len(cmd) == 0 { + logger.Infof("starting interactive shell: %s", shell) + } else { + logger.Infof("executing command: %s", safeLogCommand(cmd)) + } + + s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) + return true +} + +// getShellCommandArgs returns the shell command and arguments for executing a command string +func (s *Server) getShellCommandArgs(shell, cmdString string) []string { + if cmdString == "" { + return []string{shell, "-NoLogo"} + } + return []string{shell, "-Command", cmdString} +} + +func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) { + logger.Info("starting interactive shell") + s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand()) +} + +type PtyExecutionRequest struct { + Shell string + Command string + Width int + Height int + Username string + Domain string +} + +func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error { + log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", + req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) + + privilegeDropper := NewPrivilegeDropper() + userToken, err := privilegeDropper.createToken(req.Username, req.Domain) + if err != nil { + return fmt.Errorf("create user token: %w", err) + } + defer func() { + if err := windows.CloseHandle(userToken); err != nil { + log.Debugf("close user token: %v", err) + } + }() + + server := &Server{} + userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) + if err != nil { + log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) + userEnv = os.Environ() + } + + workingDir := getUserHomeFromEnv(userEnv) + if workingDir == "" { + workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username) + } + + ptyConfig := winpty.PtyConfig{ + Shell: req.Shell, + Command: req.Command, + Width: req.Width, + Height: req.Height, + WorkingDir: workingDir, + } + + userConfig := winpty.UserConfig{ + Token: userToken, + Environment: userEnv, + } + + log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) + return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig) +} + +func getUserHomeFromEnv(env []string) string { + for _, envVar := range env { + if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" { + return envVar[12:] + } + } + return "" +} + +func (s *Server) setupProcessGroup(_ *exec.Cmd) { + // Windows doesn't support process groups in the same way as Unix + // Process creation groups are handled differently +} + +func (s *Server) killProcessGroup(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + + logger := log.WithField("pid", cmd.Process.Pid) + + if err := cmd.Process.Kill(); err != nil { + logger.Debugf("kill process failed: %v", err) + } +} + +// detectSuPtySupport always returns false on Windows as su is not available +func (s *Server) detectSuPtySupport(context.Context) bool { + return false +} + +// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + command := session.RawCommand() + if command == "" { + logger.Error("no command specified for PTY execution") + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command) +} + +// executeConPtyCommand executes a command using ConPty (common for interactive and command execution) +func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool { + localUser := privilegeResult.User + if localUser == nil { + logger.Errorf("no user in privilege result") + return false + } + + username, domain := s.parseUsername(localUser.Username) + shell := getUserShell(localUser.Uid) + + req := PtyExecutionRequest{ + Shell: shell, + Command: command, + Width: ptyReq.Window.Width, + Height: ptyReq.Window.Height, + Username: username, + Domain: domain, + } + + if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil { + logger.Errorf("ConPty execution failed: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + logger.Debug("ConPty execution completed") + return true +} diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go new file mode 100644 index 000000000..34ffccfd2 --- /dev/null +++ b/client/ssh/server/compatibility_test.go @@ -0,0 +1,722 @@ +package server + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "fmt" + "io" + "net" + "os" + "os/exec" + "runtime" + "strings" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/testutil" +) + +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Guard against infinite recursion when test binary is called as "netbird ssh exec" + // This happens when running tests as non-privileged user with fallback + if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { + // Just exit with error to break the recursion + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") + os.Exit(1) + } + + // Run tests + code := m.Run() + + // Cleanup any created test users + testutil.CleanupTestUsers() + + os.Exit(code) +} + +// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client +func TestSSHServerCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH compatibility tests in short mode") + } + + // Check if ssh binary is available + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Set up SSH server - use our existing key generation for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate OpenSSH-compatible keys for client + clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Create temporary key files for SSH client + clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH) + defer cleanupKey() + + // Extract host and port from server address + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + // Get appropriate user for SSH connection (handle system accounts) + username := testutil.GetTestUsername(t) + + t.Run("basic command execution", func(t *testing.T) { + testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username) + }) + + t.Run("interactive command", func(t *testing.T) { + testSSHInteractiveCommand(t, host, portStr, clientKeyFile) + }) + + t.Run("port forwarding", func(t *testing.T) { + testSSHPortForwarding(t, host, portStr, clientKeyFile) + }) +} + +// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user. +func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) { + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "echo", "hello_world") + + output, err := cmd.CombinedOutput() + + if err != nil { + t.Logf("SSH command failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully") +} + +// testSSHInteractiveCommand tests interactive shell session. +func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host)) + + stdin, err := cmd.StdinPipe() + if err != nil { + t.Skipf("Cannot create stdin pipe: %v", err) + return + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Skipf("Cannot create stdout pipe: %v", err) + return + } + + err = cmd.Start() + if err != nil { + t.Logf("Cannot start SSH session: %v", err) + return + } + + go func() { + defer func() { + if err := stdin.Close(); err != nil { + t.Logf("stdin close error: %v", err) + } + }() + time.Sleep(100 * time.Millisecond) + if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil { + t.Logf("stdin write error: %v", err) + } + time.Sleep(100 * time.Millisecond) + if _, err := stdin.Write([]byte("exit\n")); err != nil { + t.Logf("stdin write error: %v", err) + } + }() + + output, err := io.ReadAll(stdout) + if err != nil { + t.Logf("Cannot read SSH output: %v", err) + } + + err = cmd.Wait() + if err != nil { + t.Logf("SSH interactive session error: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work") +} + +// testSSHPortForwarding tests port forwarding compatibility. +func testSSHPortForwarding(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + testServer, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer testServer.Close() + + testServerAddr := testServer.Addr().String() + expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK" + + go func() { + for { + conn, err := testServer.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { + if err := c.Close(); err != nil { + t.Logf("test server connection close error: %v", err) + } + }() + buf := make([]byte, 1024) + if _, err := c.Read(buf); err != nil { + t.Logf("Test server read error: %v", err) + } + if _, err := c.Write([]byte(expectedResponse)); err != nil { + t.Logf("Test server write error: %v", err) + } + }(conn) + } + }() + + localListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + localAddr := localListener.Addr().String() + localListener.Close() + + _, localPort, err := net.SplitHostPort(localAddr) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr) + cmd := exec.CommandContext(ctx, "ssh", + "-i", keyFile, + "-p", port, + "-L", forwardSpec, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-N", + fmt.Sprintf("%s@%s", username, host)) + + err = cmd.Start() + if err != nil { + t.Logf("Cannot start SSH port forwarding: %v", err) + return + } + + defer func() { + if cmd.Process != nil { + if err := cmd.Process.Kill(); err != nil { + t.Logf("process kill error: %v", err) + } + } + if err := cmd.Wait(); err != nil { + t.Logf("process wait after kill: %v", err) + } + }() + + time.Sleep(500 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second) + if err != nil { + t.Logf("Cannot connect to forwarded port: %v", err) + return + } + defer func() { + if err := conn.Close(); err != nil { + t.Logf("forwarded connection close error: %v", err) + } + }() + + request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + _, err = conn.Write([]byte(request)) + require.NoError(t, err) + + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + log.Debugf("failed to set read deadline: %v", err) + } + response := make([]byte, len(expectedResponse)) + n, err := io.ReadFull(conn, response) + if err != nil { + t.Logf("Cannot read forwarded response: %v", err) + return + } + + assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes") + assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding") +} + +// isSSHClientAvailable checks if the ssh binary is available +func isSSHClientAvailable() bool { + _, err := exec.LookPath("ssh") + return err == nil +} + +// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use. +func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) { + // Check if ssh-keygen is available + if _, err := exec.LookPath("ssh-keygen"); err != nil { + // Fall back to our existing key generation and try to convert + return generateOpenSSHKeyFallback() + } + + // Create temporary file for ssh-keygen + tempFile, err := os.CreateTemp("", "ssh_keygen_*") + if err != nil { + return nil, nil, fmt.Errorf("create temp file: %w", err) + } + keyPath := tempFile.Name() + tempFile.Close() + + // Remove the temp file so ssh-keygen can create it + if err := os.Remove(keyPath); err != nil { + t.Logf("failed to remove key file: %v", err) + } + + // Clean up temp files + defer func() { + if err := os.Remove(keyPath); err != nil { + t.Logf("failed to cleanup key file: %v", err) + } + if err := os.Remove(keyPath + ".pub"); err != nil { + t.Logf("failed to cleanup public key file: %v", err) + } + }() + + // Generate key using ssh-keygen + cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output)) + } + + // Read private key + privKeyBytes, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, fmt.Errorf("read private key: %w", err) + } + + // Read public key + pubKeyBytes, err := os.ReadFile(keyPath + ".pub") + if err != nil { + return nil, nil, fmt.Errorf("read public key: %w", err) + } + + return privKeyBytes, pubKeyBytes, nil +} + +// generateOpenSSHKeyFallback falls back to generating keys using our existing method +func generateOpenSSHKeyFallback() ([]byte, []byte, error) { + // Generate shared.ED25519 key pair using our existing method + _, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("generate key: %w", err) + } + + // Convert to SSH format + sshPrivKey, err := ssh.NewSignerFromKey(privKey) + if err != nil { + return nil, nil, fmt.Errorf("create signer: %w", err) + } + + // For the fallback, just use our PKCS#8 format and hope it works + // This won't be in OpenSSH format but might still work with some SSH clients + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + if err != nil { + return nil, nil, fmt.Errorf("generate fallback key: %w", err) + } + + // Get public key in SSH format + sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey()) + + return hostKey, sshPubKey, nil +} + +// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes +func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) { + t.Helper() + + tempFile, err := os.CreateTemp("", "ssh_test_key_*") + require.NoError(t, err) + + _, err = tempFile.Write(keyBytes) + require.NoError(t, err) + + err = tempFile.Close() + require.NoError(t, err) + + // Set proper permissions for SSH key (readable by owner only) + err = os.Chmod(tempFile.Name(), 0600) + require.NoError(t, err) + + cleanup := func() { + _ = os.Remove(tempFile.Name()) + } + + return tempFile.Name(), cleanup +} + +// createTempKeyFile creates a temporary SSH private key file (for backward compatibility) +func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) { + return createTempKeyFileFromBytes(t, privateKey) +} + +// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility +func TestSSHServerFeatureCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH feature compatibility tests in short mode") + } + + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Test various SSH features + testCases := []struct { + name string + testFunc func(t *testing.T, host, port, keyFile string) + description string + }{ + { + name: "command_with_flags", + testFunc: testCommandWithFlags, + description: "Commands with flags should work like standard SSH", + }, + { + name: "environment_variables", + testFunc: testEnvironmentVariables, + description: "Environment variables should be available", + }, + { + name: "exit_codes", + testFunc: testExitCodes, + description: "Exit codes should be properly handled", + }, + } + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.testFunc(t, host, portStr, clientKeyFile) + }) + } +} + +// testCommandWithFlags tests that commands with flags work properly +func testCommandWithFlags(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Test ls with flags + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "ls", "-la", "/tmp") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Command with flags failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + // Should not be empty and should not contain error messages + assert.NotEmpty(t, string(output), "ls -la should produce output") + assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed") +} + +// testEnvironmentVariables tests that environment is properly set up +func testEnvironmentVariables(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "echo", "$HOME") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Environment test failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + // HOME environment variable should be available + homeOutput := strings.TrimSpace(string(output)) + assert.NotEmpty(t, homeOutput, "HOME environment variable should be set") + assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded") +} + +// testExitCodes tests that exit codes are properly handled +func testExitCodes(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Test successful command (exit code 0) + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "true") // always succeeds + + err := cmd.Run() + assert.NoError(t, err, "Command with exit code 0 should succeed") + + // Test failing command (exit code 1) + cmd = exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "false") // always fails + + err = cmd.Run() + assert.Error(t, err, "Command with exit code 1 should fail") + + // Check if it's the right kind of error + if exitError, ok := err.(*exec.ExitError); ok { + assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved") + } +} + +// TestSSHServerSecurityFeatures tests security-related SSH features +func TestSSHServerSecurityFeatures(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH security tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Set up SSH server with specific security settings + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + t.Run("key_authentication", func(t *testing.T) { + // Test that key authentication works + cmd := exec.Command("ssh", + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "PasswordAuthentication=no", + fmt.Sprintf("%s@%s", username, host), + "echo", "auth_success") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Key authentication failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "auth_success", "Key authentication should work") + }) + + t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) { + // Create a different key that shouldn't be accepted + wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey) + defer cleanupWrongKey() + + // Test that wrong key is rejected + cmd := exec.Command("ssh", + "-i", wrongKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "PasswordAuthentication=no", + fmt.Sprintf("%s@%s", username, host), + "echo", "should_not_work") + + err = cmd.Run() + assert.NoError(t, err, "Any key should work in no-auth mode") + }) +} + +// TestCrossPlatformCompatibility tests cross-platform behavior +func TestCrossPlatformCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping cross-platform compatibility tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + // Test platform-specific commands + var testCommand string + + switch runtime.GOOS { + case "windows": + testCommand = "echo %OS%" + default: + testCommand = "uname" + } + + cmd := exec.Command("ssh", + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + testCommand) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Platform-specific command failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + outputStr := strings.TrimSpace(string(output)) + t.Logf("Platform command output: %s", outputStr) + assert.NotEmpty(t, outputStr, "Platform-specific command should produce output") +} diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go new file mode 100644 index 000000000..8adc824ef --- /dev/null +++ b/client/ssh/server/executor_unix.go @@ -0,0 +1,253 @@ +//go:build unix + +package server + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + "syscall" + + log "github.com/sirupsen/logrus" +) + +// Exit codes for executor process communication +const ( + ExitCodeSuccess = 0 + ExitCodePrivilegeDropFail = 10 + ExitCodeShellExecFail = 11 + ExitCodeValidationFail = 12 +) + +// ExecutorConfig holds configuration for the executor process +type ExecutorConfig struct { + UID uint32 + GID uint32 + Groups []uint32 + WorkingDir string + Shell string + Command string + PTY bool +} + +// PrivilegeDropper handles secure privilege dropping in child processes +type PrivilegeDropper struct{} + +// NewPrivilegeDropper creates a new privilege dropper +func NewPrivilegeDropper() *PrivilegeDropper { + return &PrivilegeDropper{} +} + +// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping +func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) { + netbirdPath, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("get netbird executable path: %w", err) + } + + if err := pd.validatePrivileges(config.UID, config.GID); err != nil { + return nil, fmt.Errorf("invalid privileges: %w", err) + } + + args := []string{ + "ssh", "exec", + "--uid", fmt.Sprintf("%d", config.UID), + "--gid", fmt.Sprintf("%d", config.GID), + "--working-dir", config.WorkingDir, + "--shell", config.Shell, + } + + for _, group := range config.Groups { + args = append(args, "--groups", fmt.Sprintf("%d", group)) + } + + if config.PTY { + args = append(args, "--pty") + } + + if config.Command != "" { + args = append(args, "--cmd", config.Command) + } + + // Log executor args safely - show all args except hide the command value + safeArgs := make([]string, len(args)) + copy(safeArgs, args) + for i := 0; i < len(safeArgs)-1; i++ { + if safeArgs[i] == "--cmd" { + cmdParts := strings.Fields(safeArgs[i+1]) + safeArgs[i+1] = safeLogCommand(cmdParts) + break + } + } + log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs) + return exec.CommandContext(ctx, netbirdPath, args...), nil +} + +// DropPrivileges performs privilege dropping with thread locking for security +func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error { + if err := pd.validatePrivileges(targetUID, targetGID); err != nil { + return fmt.Errorf("invalid privileges: %w", err) + } + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + originalUID := os.Geteuid() + originalGID := os.Getegid() + + if originalUID != int(targetUID) || originalGID != int(targetGID) { + if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil { + return fmt.Errorf("set groups and IDs: %w", err) + } + } + + if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil { + return err + } + + log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID) + return nil +} + +// setGroupsAndIDs sets the supplementary groups, GID, and UID +func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error { + groups := make([]int, len(supplementaryGroups)) + for i, g := range supplementaryGroups { + groups[i] = int(g) + } + + if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" { + if len(groups) == 0 || groups[0] != int(targetGID) { + groups = append([]int{int(targetGID)}, groups...) + } + } + + if err := syscall.Setgroups(groups); err != nil { + return fmt.Errorf("setgroups to %v: %w", groups, err) + } + + if err := syscall.Setgid(int(targetGID)); err != nil { + return fmt.Errorf("setgid to %d: %w", targetGID, err) + } + + if err := syscall.Setuid(int(targetUID)); err != nil { + return fmt.Errorf("setuid to %d: %w", targetUID, err) + } + + return nil +} + +// validatePrivilegeDropSuccess validates that privilege dropping was successful +func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error { + if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil { + return err + } + + if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil { + return err + } + + return nil +} + +// validatePrivilegeDropReversibility ensures privileges cannot be restored +func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error { + if originalGID != int(targetGID) { + if err := syscall.Setegid(originalGID); err == nil { + return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID) + } + } + if originalUID != int(targetUID) { + if err := syscall.Seteuid(originalUID); err == nil { + return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID) + } + } + return nil +} + +// validateCurrentPrivileges validates the current UID and GID match the target +func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error { + currentUID := os.Geteuid() + if currentUID != int(targetUID) { + return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID) + } + + currentGID := os.Getegid() + if currentGID != int(targetGID) { + return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID) + } + + return nil +} + +// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures +func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) { + log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups) + + // TODO: Implement Pty support for executor path + if config.PTY { + config.PTY = false + } + + if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err) + os.Exit(ExitCodePrivilegeDropFail) + } + + if config.WorkingDir != "" { + if err := os.Chdir(config.WorkingDir); err != nil { + log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err) + } + } + + var execCmd *exec.Cmd + if config.Command == "" { + os.Exit(ExitCodeSuccess) + } + + execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) + execCmd.Stdin = os.Stdin + execCmd.Stdout = os.Stdout + execCmd.Stderr = os.Stderr + + cmdParts := strings.Fields(config.Command) + safeCmd := safeLogCommand(cmdParts) + log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + if err := execCmd.Run(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + // Normal command exit with non-zero code - not an SSH execution error + log.Tracef("command exited with code %d", exitError.ExitCode()) + os.Exit(exitError.ExitCode()) + } + + // Actual execution failure (command not found, permission denied, etc.) + log.Debugf("command execution failed: %v", err) + os.Exit(ExitCodeShellExecFail) + } + + os.Exit(ExitCodeSuccess) +} + +// validatePrivileges validates that privilege dropping to the target UID/GID is allowed +func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error { + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + // Allow same-user operations (no privilege dropping needed) + if uid == currentUID && gid == currentGID { + return nil + } + + // Only root can drop privileges to other users + if currentUID != 0 { + return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid) + } + + // Root can drop to any user (including root itself) + return nil +} diff --git a/client/ssh/server/executor_unix_test.go b/client/ssh/server/executor_unix_test.go new file mode 100644 index 000000000..0c5108f57 --- /dev/null +++ b/client/ssh/server/executor_unix_test.go @@ -0,0 +1,262 @@ +//go:build unix + +package server + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { + pd := NewPrivilegeDropper() + + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + tests := []struct { + name string + uid uint32 + gid uint32 + wantErr bool + }{ + { + name: "same user - no privilege drop needed", + uid: currentUID, + gid: currentGID, + wantErr: false, + }, + { + name: "non-root to different user should fail", + uid: currentUID + 1, // Use a different UID to ensure it's actually different + gid: currentGID + 1, // Use a different GID to ensure it's actually different + wantErr: currentUID != 0, // Only fail if current user is not root + }, + { + name: "root can drop to any user", + uid: 1000, + gid: 1000, + wantErr: false, + }, + { + name: "root can stay as root", + uid: 0, + gid: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip non-root tests when running as root, and root tests when not root + if tt.name == "non-root to different user should fail" && currentUID == 0 { + t.Skip("Skipping non-root test when running as root") + } + if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 { + t.Skip("Skipping root test when not running as root") + } + + err := pd.validatePrivileges(tt.uid, tt.gid) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { + pd := NewPrivilegeDropper() + + config := ExecutorConfig{ + UID: 1000, + GID: 1000, + Groups: []uint32{1000, 1001}, + WorkingDir: "/home/testuser", + Shell: "/bin/bash", + Command: "ls -la", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, cmd) + + // Verify the command is calling netbird ssh exec + assert.Contains(t, cmd.Args, "ssh") + assert.Contains(t, cmd.Args, "exec") + assert.Contains(t, cmd.Args, "--uid") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "--gid") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "--groups") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "1001") + assert.Contains(t, cmd.Args, "--working-dir") + assert.Contains(t, cmd.Args, "/home/testuser") + assert.Contains(t, cmd.Args, "--shell") + assert.Contains(t, cmd.Args, "/bin/bash") + assert.Contains(t, cmd.Args, "--cmd") + assert.Contains(t, cmd.Args, "ls -la") +} + +func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) { + pd := NewPrivilegeDropper() + + config := ExecutorConfig{ + UID: 1000, + GID: 1000, + Groups: []uint32{1000}, + WorkingDir: "/home/testuser", + Shell: "/bin/bash", + Command: "", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, cmd) + + // Verify no command mode (command is empty so no --cmd flag) + assert.NotContains(t, cmd.Args, "--cmd") + assert.NotContains(t, cmd.Args, "--interactive") +} + +// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping +// This test requires root privileges and will be skipped if not running as root +func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("This test requires root privileges") + } + + // Find a non-root user to test with + testUser, err := findNonRootUser() + if err != nil { + t.Skip("No suitable non-root user found for testing") + } + + // Verify the user actually exists by looking it up again + _, err = user.LookupId(testUser.Uid) + if err != nil { + t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err) + } + + uid64, err := strconv.ParseUint(testUser.Uid, 10, 32) + require.NoError(t, err) + targetUID := uint32(uid64) + + gid64, err := strconv.ParseUint(testUser.Gid, 10, 32) + require.NoError(t, err) + targetGID := uint32(gid64) + + // Test in a child process to avoid affecting the test runner + if os.Getenv("TEST_PRIVILEGE_DROP") == "1" { + pd := NewPrivilegeDropper() + + // This should succeed + err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID}) + require.NoError(t, err) + + // Verify we are now running as the target user + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + assert.Equal(t, targetUID, currentUID, "UID should match target") + assert.Equal(t, targetGID, currentGID, "GID should match target") + assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root") + assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group") + + return + } + + // Fork a child process to test privilege dropping + cmd := os.Args[0] + args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"} + + env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1") + + execCmd := exec.Command(cmd, args...) + execCmd.Env = env + + err = execCmd.Run() + require.NoError(t, err, "Child process should succeed") +} + +// findNonRootUser finds any non-root user on the system for testing +func findNonRootUser() (*user.User, error) { + // Try common non-root users, but avoid "nobody" on macOS due to negative UID issues + commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"} + + for _, username := range commonUsers { + if u, err := user.Lookup(username); err == nil { + // Parse as signed integer first to handle negative UIDs + uid64, err := strconv.ParseInt(u.Uid, 10, 32) + if err != nil { + continue + } + // Skip negative UIDs (like nobody=-2 on macOS) and root + if uid64 > 0 && uid64 != 0 { + return u, nil + } + } + } + + // If no common users found, try to find any regular user with UID > 100 + // This helps on macOS where regular users start at UID 501 + allUsers := []string{"vma", "user", "test", "admin"} + for _, username := range allUsers { + if u, err := user.Lookup(username); err == nil { + uid64, err := strconv.ParseInt(u.Uid, 10, 32) + if err != nil { + continue + } + if uid64 > 100 { // Regular user + return u, nil + } + } + } + + // If no common users found, return an error + return nil, fmt.Errorf("no suitable non-root user found on this system") +} + +func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) { + pd := NewPrivilegeDropper() + currentUID := uint32(os.Geteuid()) + + if currentUID == 0 { + // When running as root, test that root can create commands for any user + config := ExecutorConfig{ + UID: 1000, // Target non-root user + GID: 1000, + Groups: []uint32{1000}, + WorkingDir: "/tmp", + Shell: "/bin/sh", + Command: "echo test", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + assert.NoError(t, err, "Root should be able to create commands for any user") + assert.NotNil(t, cmd) + } else { + // When running as non-root, test that we can't drop to a different user + config := ExecutorConfig{ + UID: 0, // Try to target root + GID: 0, + Groups: []uint32{0}, + WorkingDir: "/tmp", + Shell: "/bin/sh", + Command: "echo test", + } + + _, err := pd.CreateExecutorCommand(context.Background(), config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot drop privileges") + } +} diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go new file mode 100644 index 000000000..d3504e056 --- /dev/null +++ b/client/ssh/server/executor_windows.go @@ -0,0 +1,570 @@ +//go:build windows + +package server + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "os/user" + "strings" + "syscall" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + ExitCodeSuccess = 0 + ExitCodeLogonFail = 10 + ExitCodeCreateProcessFail = 11 + ExitCodeWorkingDirFail = 12 + ExitCodeShellExecFail = 13 + ExitCodeValidationFail = 14 +) + +type WindowsExecutorConfig struct { + Username string + Domain string + WorkingDir string + Shell string + Command string + Args []string + Interactive bool + Pty bool + PtyWidth int + PtyHeight int +} + +type PrivilegeDropper struct{} + +func NewPrivilegeDropper() *PrivilegeDropper { + return &PrivilegeDropper{} +} + +var ( + advapi32 = windows.NewLazyDLL("advapi32.dll") + procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId") +) + +const ( + logon32LogonNetwork = 3 // Network logon - no password required for authenticated users + + // Common error messages + commandFlag = "-Command" + closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials + convertUsernameError = "convert username to UTF16: %w" + convertDomainError = "convert domain to UTF16: %w" +) + +// CreateWindowsExecutorCommand creates a Windows command with privilege dropping. +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) { + if config.Username == "" { + return nil, 0, errors.New("username cannot be empty") + } + if config.Shell == "" { + return nil, 0, errors.New("shell cannot be empty") + } + + shell := config.Shell + + var shellArgs []string + if config.Command != "" { + shellArgs = []string{shell, commandFlag, config.Command} + } else { + shellArgs = []string{shell} + } + + log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) + + cmd, token, err := pd.CreateWindowsProcessAsUser( + ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) + if err != nil { + return nil, 0, fmt.Errorf("create Windows process as user: %w", err) + } + + return cmd, token, nil +} + +const ( + // StatusSuccess represents successful LSA operation + StatusSuccess = 0 + + // KerbS4ULogonType message type for domain users with Kerberos + KerbS4ULogonType = 12 + // Msv10s4ulogontype message type for local users with MSV1_0 + Msv10s4ulogontype = 12 + + // MicrosoftKerberosNameA is the authentication package name for Kerberos + MicrosoftKerberosNameA = "Kerberos" + // Msv10packagename is the authentication package name for MSV1_0 + Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0" + + NameSamCompatible = 2 + NameUserPrincipal = 8 + NameCanonical = 7 + + maxUPNLen = 1024 +) + +// kerbS4ULogon structure for S4U authentication (domain users) +type kerbS4ULogon struct { + MessageType uint32 + Flags uint32 + ClientUpn unicodeString + ClientRealm unicodeString +} + +// msv10s4ulogon structure for S4U authentication (local users) +type msv10s4ulogon struct { + MessageType uint32 + Flags uint32 + UserPrincipalName unicodeString + DomainName unicodeString +} + +// unicodeString structure +type unicodeString struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +// lsaString structure +type lsaString struct { + Length uint16 + MaximumLength uint16 + Buffer *byte +} + +// tokenSource structure +type tokenSource struct { + SourceName [8]byte + SourceIdentifier windows.LUID +} + +// quotaLimits structure +type quotaLimits struct { + PagedPoolLimit uint32 + NonPagedPoolLimit uint32 + MinimumWorkingSetSize uint32 + MaximumWorkingSetSize uint32 + PagefileLimit uint32 + TimeLimit int64 +} + +var ( + secur32 = windows.NewLazyDLL("secur32.dll") + procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess") + procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage") + procLsaLogonUser = secur32.NewProc("LsaLogonUser") + procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer") + procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess") + procTranslateNameW = secur32.NewProc("TranslateNameW") +) + +// newLsaString creates an LsaString from a Go string +func newLsaString(s string) lsaString { + b := append([]byte(s), 0) + return lsaString{ + Length: uint16(len(s)), + MaximumLength: uint16(len(b)), + Buffer: &b[0], + } +} + +// generateS4UUserToken creates a Windows token using S4U authentication +// This is the exact approach OpenSSH for Windows uses for public key authentication +func generateS4UUserToken(username, domain string) (windows.Handle, error) { + userCpn := buildUserCpn(username, domain) + + pd := NewPrivilegeDropper() + isDomainUser := !pd.isLocalUser(domain) + + lsaHandle, err := initializeLsaConnection() + if err != nil { + return 0, err + } + defer cleanupLsaConnection(lsaHandle) + + authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser) + if err != nil { + return 0, err + } + + logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser) + if err != nil { + return 0, err + } + + return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) +} + +// buildUserCpn constructs the user principal name +func buildUserCpn(username, domain string) string { + if domain != "" && domain != "." { + return fmt.Sprintf(`%s\%s`, domain, username) + } + return username +} + +// initializeLsaConnection establishes connection to LSA +func initializeLsaConnection() (windows.Handle, error) { + + processName := newLsaString("NetBird") + var mode uint32 + var lsaHandle windows.Handle + ret, _, _ := procLsaRegisterLogonProcess.Call( + uintptr(unsafe.Pointer(&processName)), + uintptr(unsafe.Pointer(&lsaHandle)), + uintptr(unsafe.Pointer(&mode)), + ) + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret) + } + + return lsaHandle, nil +} + +// cleanupLsaConnection closes the LSA connection +func cleanupLsaConnection(lsaHandle windows.Handle) { + if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess { + log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret) + } +} + +// lookupAuthenticationPackage finds the correct authentication package +func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) { + var authPackageName lsaString + if isDomainUser { + authPackageName = newLsaString(MicrosoftKerberosNameA) + } else { + authPackageName = newLsaString(Msv10packagename) + } + + var authPackageId uint32 + ret, _, _ := procLsaLookupAuthenticationPackage.Call( + uintptr(lsaHandle), + uintptr(unsafe.Pointer(&authPackageName)), + uintptr(unsafe.Pointer(&authPackageId)), + ) + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret) + } + + return authPackageId, nil +} + +// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format) +func lookupPrincipalName(username, domain string) (string, error) { + samAccountName := fmt.Sprintf(`%s\%s`, domain, username) + samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName) + if err != nil { + return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err) + } + + upnBuf := make([]uint16, maxUPNLen+1) + upnSize := uint32(len(upnBuf)) + + ret, _, _ := procTranslateNameW.Call( + uintptr(unsafe.Pointer(samAccountNameUtf16)), + uintptr(NameSamCompatible), + uintptr(NameUserPrincipal), + uintptr(unsafe.Pointer(&upnBuf[0])), + uintptr(unsafe.Pointer(&upnSize)), + ) + + if ret != 0 { + upn := windows.UTF16ToString(upnBuf[:upnSize]) + log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn) + return upn, nil + } + + upnSize = uint32(len(upnBuf)) + ret, _, _ = procTranslateNameW.Call( + uintptr(unsafe.Pointer(samAccountNameUtf16)), + uintptr(NameSamCompatible), + uintptr(NameCanonical), + uintptr(unsafe.Pointer(&upnBuf[0])), + uintptr(unsafe.Pointer(&upnSize)), + ) + + if ret != 0 { + canonical := windows.UTF16ToString(upnBuf[:upnSize]) + slashIdx := strings.IndexByte(canonical, '/') + if slashIdx > 0 { + fqdn := canonical[:slashIdx] + upn := fmt.Sprintf("%s@%s", username, fqdn) + log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical) + return upn, nil + } + } + + log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName) + return samAccountName, nil +} + +// prepareS4ULogonStructure creates the appropriate S4U logon structure +func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { + if isDomainUser { + return prepareDomainS4ULogon(username, domain) + } + return prepareLocalS4ULogon(username) +} + +// prepareDomainS4ULogon creates S4U logon structure for domain users +func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) { + upn, err := lookupPrincipalName(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("lookup principal name: %w", err) + } + + log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) + + upnUtf16, err := windows.UTF16FromString(upn) + if err != nil { + return nil, 0, fmt.Errorf(convertUsernameError, err) + } + + structSize := unsafe.Sizeof(kerbS4ULogon{}) + upnByteSize := len(upnUtf16) * 2 + logonInfoSize := structSize + uintptr(upnByteSize) + + buffer := make([]byte, logonInfoSize) + logonInfo := unsafe.Pointer(&buffer[0]) + + s4uLogon := (*kerbS4ULogon)(logonInfo) + s4uLogon.MessageType = KerbS4ULogonType + s4uLogon.Flags = 0 + + upnOffset := structSize + upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset)) + copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16) + + s4uLogon.ClientUpn = unicodeString{ + Length: uint16((len(upnUtf16) - 1) * 2), + MaximumLength: uint16(len(upnUtf16) * 2), + Buffer: upnBuffer, + } + s4uLogon.ClientRealm = unicodeString{} + + return logonInfo, logonInfoSize, nil +} + +// prepareLocalS4ULogon creates S4U logon structure for local users +func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { + log.Debugf("using Msv1_0S4ULogon for local user: %s", username) + + usernameUtf16, err := windows.UTF16FromString(username) + if err != nil { + return nil, 0, fmt.Errorf(convertUsernameError, err) + } + + domainUtf16, err := windows.UTF16FromString(".") + if err != nil { + return nil, 0, fmt.Errorf(convertDomainError, err) + } + + structSize := unsafe.Sizeof(msv10s4ulogon{}) + usernameByteSize := len(usernameUtf16) * 2 + domainByteSize := len(domainUtf16) * 2 + logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize) + + buffer := make([]byte, logonInfoSize) + logonInfo := unsafe.Pointer(&buffer[0]) + + s4uLogon := (*msv10s4ulogon)(logonInfo) + s4uLogon.MessageType = Msv10s4ulogontype + s4uLogon.Flags = 0x0 + + usernameOffset := structSize + usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset)) + copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16) + + s4uLogon.UserPrincipalName = unicodeString{ + Length: uint16((len(usernameUtf16) - 1) * 2), + MaximumLength: uint16(len(usernameUtf16) * 2), + Buffer: usernameBuffer, + } + + domainOffset := usernameOffset + uintptr(usernameByteSize) + domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset)) + copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16) + + s4uLogon.DomainName = unicodeString{ + Length: uint16((len(domainUtf16) - 1) * 2), + MaximumLength: uint16(len(domainUtf16) * 2), + Buffer: domainBuffer, + } + + return logonInfo, logonInfoSize, nil +} + +// performS4ULogon executes the S4U logon operation +func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { + var tokenSource tokenSource + copy(tokenSource.SourceName[:], "netbird") + if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 { + log.Debugf("AllocateLocallyUniqueId failed") + } + + originName := newLsaString("netbird") + + var profile uintptr + var profileSize uint32 + var logonId windows.LUID + var token windows.Handle + var quotas quotaLimits + var subStatus int32 + + ret, _, _ := procLsaLogonUser.Call( + uintptr(lsaHandle), + uintptr(unsafe.Pointer(&originName)), + logon32LogonNetwork, + uintptr(authPackageId), + uintptr(logonInfo), + logonInfoSize, + 0, + uintptr(unsafe.Pointer(&tokenSource)), + uintptr(unsafe.Pointer(&profile)), + uintptr(unsafe.Pointer(&profileSize)), + uintptr(unsafe.Pointer(&logonId)), + uintptr(unsafe.Pointer(&token)), + uintptr(unsafe.Pointer("as)), + uintptr(unsafe.Pointer(&subStatus)), + ) + + if profile != 0 { + if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess { + log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) + } + } + + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus) + } + + log.Debugf("created S4U %s token for user %s", + map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn) + return token, nil +} + +// createToken implements NetBird trust-based authentication using S4U +func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) { + fullUsername := buildUserCpn(username, domain) + + if err := userExists(fullUsername, username, domain); err != nil { + return 0, err + } + + isLocalUser := pd.isLocalUser(domain) + + if isLocalUser { + return pd.authenticateLocalUser(username, fullUsername) + } + return pd.authenticateDomainUser(username, domain, fullUsername) +} + +// userExists checks if the target useVerifier exists on the system +func userExists(fullUsername, username, domain string) error { + if _, err := lookupUser(fullUsername); err != nil { + log.Debugf("User %s not found: %v", fullUsername, err) + if domain != "" && domain != "." { + _, err = lookupUser(username) + } + if err != nil { + return fmt.Errorf("target user %s not found: %w", fullUsername, err) + } + } + return nil +} + +// isLocalUser determines if this is a local user vs domain user +func (pd *PrivilegeDropper) isLocalUser(domain string) bool { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + + return domain == "" || domain == "." || + strings.EqualFold(domain, hostname) +} + +// authenticateLocalUser handles authentication for local users +func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) { + log.Debugf("using S4U authentication for local user %s", fullUsername) + token, err := generateS4UUserToken(username, ".") + if err != nil { + return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err) + } + return token, nil +} + +// authenticateDomainUser handles authentication for domain users +func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) { + log.Debugf("using S4U authentication for domain user %s", fullUsername) + token, err := generateS4UUserToken(username, domain) + if err != nil { + return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err) + } + log.Debugf("Successfully created S4U token for domain user %s", fullUsername) + return token, nil +} + +// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables). +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) { + token, err := pd.createToken(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("user authentication: %w", err) + } + + defer func() { + if err := windows.CloseHandle(token); err != nil { + log.Debugf("close impersonation token: %v", err) + } + }() + + cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir) + if err != nil { + return nil, 0, err + } + + return cmd, primaryToken, nil +} + +// createProcessWithToken creates process with the specified token and executable path. +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) { + cmd := exec.CommandContext(ctx, executablePath, args[1:]...) + cmd.Dir = workingDir + + var primaryToken windows.Token + err := windows.DuplicateTokenEx( + sourceToken, + windows.TOKEN_ALL_ACCESS, + nil, + windows.SecurityIdentification, + windows.TokenPrimary, + &primaryToken, + ) + if err != nil { + return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err) + } + + cmd.SysProcAttr = &syscall.SysProcAttr{ + Token: syscall.Token(primaryToken), + } + + return cmd, primaryToken, nil +} + +// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) +func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) { + return nil, fmt.Errorf("su command not available on Windows") +} diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go new file mode 100644 index 000000000..e22bdfb06 --- /dev/null +++ b/client/ssh/server/jwt_test.go @@ -0,0 +1,629 @@ +package server + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "runtime" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/client" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/client/ssh/testutil" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" +) + +func TestJWTEnforcement(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT enforcement tests in short mode") + } + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + t.Run("blocks_without_jwt", func(t *testing.T) { + jwtConfig := &JWTConfig{ + Issuer: "test-issuer", + Audience: "test-audience", + KeysLocation: "test-keys", + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) + if err != nil { + t.Logf("Detection failed: %v", err) + } + t.Logf("Detected server type: %s", serverType) + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + _, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + assert.Error(t, err, "SSH connection should fail when JWT is required but not provided") + }) + + t.Run("allows_when_disabled", func(t *testing.T) { + serverConfigNoJWT := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + serverNoJWT := New(serverConfigNoJWT) + require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config") + serverNoJWT.SetAllowRootLogin(true) + + serverAddrNoJWT := StartTestServer(t, serverNoJWT) + defer require.NoError(t, serverNoJWT.Stop()) + + hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT) + require.NoError(t, err) + portNoJWT, err := strconv.Atoi(portStrNoJWT) + require.NoError(t, err) + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT) + require.NoError(t, err) + assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType) + assert.False(t, serverType.RequiresJWT()) + + client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT) + require.NoError(t, err) + defer client.Close() + }) + +} + +// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL +func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) { + privateKey, jwksJSON := generateTestJWKS(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(jwksJSON); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + + return server, privateKey, server.URL +} + +// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON +func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey := &privateKey.PublicKey + n := publicKey.N.Bytes() + e := publicKey.E + + jwk := nbjwt.JSONWebKey{ + Kty: "RSA", + Kid: "test-key-id", + Use: "sig", + N: base64RawURLEncode(n), + E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()), + } + + jwks := nbjwt.Jwks{ + Keys: []nbjwt.JSONWebKey{jwk}, + } + + jwksJSON, err := json.Marshal(jwks) + require.NoError(t, err) + + return privateKey, jwksJSON +} + +func base64RawURLEncode(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} + +// generateValidJWT creates a valid JWT token for testing +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + return tokenString +} + +// connectWithNetBirdClient connects to SSH server using NetBird's SSH client +func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) { + t.Helper() + addr := net.JoinHostPort(host, strconv.Itoa(port)) + + ctx := context.Background() + return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{ + InsecureSkipVerify: true, + }) +} + +// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers +func TestJWTDetection(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT detection test in short mode") + } + + jwksServer, _, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) + require.NoError(t, err) + assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType) + assert.True(t, serverType.RequiresJWT()) +} + +func TestJWTFailClose(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT fail-close tests in short mode") + } + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + testCases := []struct { + name string + tokenClaims jwt.MapClaims + }{ + { + name: "blocks_token_missing_iat", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }, + }, + { + name: "blocks_token_missing_sub", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_missing_iss", + tokenClaims: jwt.MapClaims{ + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_missing_aud", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_wrong_issuer", + tokenClaims: jwt.MapClaims{ + "iss": "wrong-issuer", + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_wrong_audience", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": "wrong-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_expired_token", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(-time.Hour).Unix(), + "iat": time.Now().Add(-2 * time.Hour).Unix(), + }, + }, + { + name: "blocks_token_exceeding_max_age", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-2 * time.Hour).Unix(), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + MaxTokenAge: 3600, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims) + token.Header["kid"] = "test-key-id" + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{ + cryptossh.Password(tokenString), + }, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + if conn != nil { + defer func() { + if err := conn.Close(); err != nil { + t.Logf("close connection: %v", err) + } + }() + } + + assert.Error(t, err, "Authentication should fail (fail-close)") + }) + } +} + +// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types +func TestJWTAuthentication(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT authentication tests in short mode") + } + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + testCases := []struct { + name string + token string + wantAuthOK bool + setupServer func(*Server) + testOperation func(*testing.T, *cryptossh.Client, string) error + wantOpSuccess bool + }{ + { + name: "allows_shell_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + return session.Shell() + }, + wantOpSuccess: true, + }, + { + name: "rejects_invalid_token", + token: "invalid", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("echo test") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "blocks_shell_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("echo test") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "blocks_command_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("ls") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "allows_sftp_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowSFTP(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + session.Stdout = io.Discard + session.Stderr = io.Discard + return session.RequestSubsystem("sftp") + }, + wantOpSuccess: true, + }, + { + name: "blocks_sftp_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowSFTP(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + session.Stdout = io.Discard + session.Stderr = io.Discard + err = session.RequestSubsystem("sftp") + if err == nil { + err = session.Wait() + } + return err + }, + wantOpSuccess: false, + }, + { + name: "allows_port_forward_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowRemotePortForwarding(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + ln, err := conn.Listen("tcp", "127.0.0.1:0") + if ln != nil { + defer ln.Close() + } + return err + }, + wantOpSuccess: true, + }, + { + name: "blocks_port_forward_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowLocalPortForwarding(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + ln, err := conn.Listen("tcp", "127.0.0.1:0") + if ln != nil { + defer ln.Close() + } + return err + }, + wantOpSuccess: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // TODO: Skip port forwarding tests on Windows - user switching not supported + // These features are tested on Linux/Unix platforms + if runtime.GOOS == "windows" && + (tc.name == "allows_port_forward_with_jwt" || + tc.name == "blocks_port_forward_without_jwt") { + t.Skip("Skipping port forwarding test on Windows - covered by Linux tests") + } + + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + if tc.setupServer != nil { + tc.setupServer(server) + } + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + var authMethods []cryptossh.AuthMethod + if tc.token == "valid" { + token := generateValidJWT(t, privateKey, issuer, audience) + authMethods = []cryptossh.AuthMethod{ + cryptossh.Password(token), + } + } else if tc.token == "invalid" { + invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid" + authMethods = []cryptossh.AuthMethod{ + cryptossh.Password(invalidToken), + } + } + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: authMethods, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + if tc.wantAuthOK { + require.NoError(t, err, "JWT authentication should succeed") + } else if err != nil { + t.Logf("Connection failed as expected: %v", err) + return + } + if conn != nil { + defer func() { + if err := conn.Close(); err != nil { + t.Logf("close connection: %v", err) + } + }() + } + + err = tc.testOperation(t, conn, serverAddr) + if tc.wantOpSuccess { + require.NoError(t, err, "Operation should succeed") + } else { + assert.Error(t, err, "Operation should fail") + } + }) + } +} diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go new file mode 100644 index 000000000..6138f9296 --- /dev/null +++ b/client/ssh/server/port_forwarding.go @@ -0,0 +1,386 @@ +package server + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" +) + +// SessionKey uniquely identifies an SSH session +type SessionKey string + +// ConnectionKey uniquely identifies a port forwarding connection within a session +type ConnectionKey string + +// ForwardKey uniquely identifies a port forwarding listener +type ForwardKey string + +// tcpipForwardMsg represents the structure for tcpip-forward SSH requests +type tcpipForwardMsg struct { + Host string + Port uint32 +} + +// SetAllowLocalPortForwarding configures local port forwarding +func (s *Server) SetAllowLocalPortForwarding(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowLocalPortForwarding = allow +} + +// SetAllowRemotePortForwarding configures remote port forwarding +func (s *Server) SetAllowRemotePortForwarding(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowRemotePortForwarding = allow +} + +// configurePortForwarding sets up port forwarding callbacks +func (s *Server) configurePortForwarding(server *ssh.Server) { + allowLocal := s.allowLocalPortForwarding + allowRemote := s.allowRemotePortForwarding + + server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { + if !allowLocal { + log.Warnf("local port forwarding denied for %s from %s: disabled by configuration", + net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr()) + return false + } + + if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { + log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err) + return false + } + + log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort) + return true + } + + server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + if !allowRemote { + log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration", + net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr()) + return false + } + + if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { + log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err) + return false + } + + log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort) + return true + } + + log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote) +} + +// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations. +// Returns nil if allowed, error if denied. +func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error { + if ctx == nil { + return fmt.Errorf("%s port forwarding denied: no context", forwardType) + } + + username := ctx.User() + remoteAddr := "unknown" + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + + logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port}) + + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: false, + FeatureName: forwardType + " port forwarding", + }) + + if !result.Allowed { + return result.Error + } + + logger.Debugf("%s port forwarding allowed: user %s validated (port %d)", + forwardType, result.User.Username, port) + + return nil +} + +// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding. +func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + logger := s.getRequestLogger(ctx) + + if !s.isRemotePortForwardingAllowed() { + logger.Warnf("tcpip-forward request denied: remote port forwarding disabled") + return false, nil + } + + payload, err := s.parseTcpipForwardRequest(req) + if err != nil { + logger.Errorf("tcpip-forward unmarshal error: %v", err) + return false, nil + } + + if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil { + logger.Warnf("tcpip-forward denied: %v", err) + return false, nil + } + + logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port) + + sshConn, err := s.getSSHConnection(ctx) + if err != nil { + logger.Warnf("tcpip-forward request denied: %v", err) + return false, nil + } + + return s.setupDirectForward(ctx, logger, sshConn, payload) +} + +// cancelTcpipForwardHandler handles cancel-tcpip-forward requests. +func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + logger := s.getRequestLogger(ctx) + + var payload tcpipForwardMsg + if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil { + logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err) + return false, nil + } + + key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + if s.removeRemoteForwardListener(key) { + logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) + return true, nil + } + + logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port) + return false, nil +} + +// handleRemoteForwardListener handles incoming connections for remote port forwarding. +func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) { + log.Debugf("starting remote forward listener handler for %s:%d", host, port) + + defer func() { + log.Debugf("cleaning up remote forward listener for %s:%d", host, port) + if err := ln.Close(); err != nil { + log.Debugf("remote forward listener close error: %v", err) + } else { + log.Debugf("remote forward listener closed successfully for %s:%d", host, port) + } + }() + + acceptChan := make(chan acceptResult, 1) + + go func() { + for { + conn, err := ln.Accept() + select { + case acceptChan <- acceptResult{conn: conn, err: err}: + if err != nil { + return + } + case <-ctx.Done(): + return + } + } + }() + + for { + select { + case result := <-acceptChan: + if result.err != nil { + log.Debugf("remote forward accept error: %v", result.err) + return + } + go s.handleRemoteForwardConnection(ctx, result.conn, host, port) + case <-ctx.Done(): + log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port) + return + } + } +} + +// getRequestLogger creates a logger with user and remote address context +func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry { + remoteAddr := "unknown" + username := "unknown" + if ctx != nil { + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + username = ctx.User() + } + return log.WithFields(log.Fields{"user": username, "remote": remoteAddr}) +} + +// isRemotePortForwardingAllowed checks if remote port forwarding is enabled +func (s *Server) isRemotePortForwardingAllowed() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.allowRemotePortForwarding +} + +// parseTcpipForwardRequest parses the SSH request payload +func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { + var payload tcpipForwardMsg + err := cryptossh.Unmarshal(req.Payload, &payload) + return &payload, err +} + +// getSSHConnection extracts SSH connection from context +func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) { + if ctx == nil { + return nil, fmt.Errorf("no context") + } + sshConnValue := ctx.Value(ssh.ContextKeyConn) + if sshConnValue == nil { + return nil, fmt.Errorf("no SSH connection in context") + } + sshConn, ok := sshConnValue.(*cryptossh.ServerConn) + if !ok || sshConn == nil { + return nil, fmt.Errorf("invalid SSH connection in context") + } + return sshConn, nil +} + +// setupDirectForward sets up a direct port forward +func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) { + bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10)) + + ln, err := net.Listen("tcp", bindAddr) + if err != nil { + logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err) + return false, nil + } + + actualPort := payload.Port + if payload.Port == 0 { + tcpAddr := ln.Addr().(*net.TCPAddr) + actualPort = uint32(tcpAddr.Port) + logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) + } + + key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + s.storeRemoteForwardListener(key, ln) + + s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String()) + go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) + + response := make([]byte, 4) + binary.BigEndian.PutUint32(response, actualPort) + + logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort) + return true, response +} + +// acceptResult holds the result of a listener Accept() call +type acceptResult struct { + conn net.Conn + err error +} + +// handleRemoteForwardConnection handles a single remote port forwarding connection +func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) { + sessionKey := s.findSessionKeyByContext(ctx) + connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port) + logger := log.WithFields(log.Fields{ + "session": sessionKey, + "conn": connID, + }) + + defer func() { + if err := conn.Close(); err != nil { + logger.Debugf("connection close error: %v", err) + } + }() + + sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) + if sshConn == nil { + logger.Debugf("remote forward: no SSH connection in context") + return + } + + remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr()) + return + } + + channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger) + if err != nil { + logger.Debugf("open forward channel: %v", err) + return + } + + s.proxyForwardConnection(ctx, logger, conn, channel) +} + +// openForwardChannel creates an SSH forwarded-tcpip channel +func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) { + logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port) + + payload := struct { + ConnectedAddress string + ConnectedPort uint32 + OriginatorAddress string + OriginatorPort uint32 + }{ + ConnectedAddress: host, + ConnectedPort: port, + OriginatorAddress: remoteAddr.IP.String(), + OriginatorPort: uint32(remoteAddr.Port), + } + + channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload)) + if err != nil { + return nil, fmt.Errorf("open SSH channel: %w", err) + } + + go cryptossh.DiscardRequests(reqs) + return channel, nil +} + +// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel +func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) { + done := make(chan struct{}, 2) + + go func() { + if _, err := io.Copy(channel, conn); err != nil { + logger.Debugf("copy error (conn->channel): %v", err) + } + done <- struct{}{} + }() + + go func() { + if _, err := io.Copy(conn, channel); err != nil { + logger.Debugf("copy error (channel->conn): %v", err) + } + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") + case <-done: + // First copy finished, wait for second copy or context cancellation + select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") + case <-done: + } + } + + if err := channel.Close(); err != nil { + logger.Debugf("channel close error: %v", err) + } + if err := conn.Close(); err != nil { + logger.Debugf("connection close error: %v", err) + } +} diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go new file mode 100644 index 000000000..44612532b --- /dev/null +++ b/client/ssh/server/server.go @@ -0,0 +1,712 @@ +package server + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + gojwt "github.com/golang-jwt/jwt/v5" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" + "golang.org/x/exp/maps" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/auth/jwt" + "github.com/netbirdio/netbird/version" +) + +// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server +const DefaultSSHPort = 22 + +// InternalSSHPort is the port SSH server listens on and is redirected to +const InternalSSHPort = 22022 + +const ( + errWriteSession = "write session error: %v" + errExitSession = "exit session error: %v" + + msgPrivilegedUserDisabled = "privileged user login is disabled" + + // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server + DefaultJWTMaxTokenAge = 5 * 60 +) + +var ( + ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled) + ErrUserNotFound = errors.New("user not found") +) + +// PrivilegedUserError represents an error when privileged user login is disabled +type PrivilegedUserError struct { + Username string +} + +func (e *PrivilegedUserError) Error() string { + return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username) +} + +func (e *PrivilegedUserError) Is(target error) bool { + return target == ErrPrivilegedUserDisabled +} + +// UserNotFoundError represents an error when a user cannot be found +type UserNotFoundError struct { + Username string + Cause error +} + +func (e *UserNotFoundError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause) + } + return fmt.Sprintf("user %s not found", e.Username) +} + +func (e *UserNotFoundError) Is(target error) bool { + return target == ErrUserNotFound +} + +func (e *UserNotFoundError) Unwrap() error { + return e.Cause +} + +// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors +func logSessionExitError(logger *log.Entry, err error) { + if err != nil && !errors.Is(err, io.EOF) { + logger.Warnf(errExitSession, err) + } +} + +// safeLogCommand returns a safe representation of the command for logging +func safeLogCommand(cmd []string) string { + if len(cmd) == 0 { + return "" + } + if len(cmd) == 1 { + return cmd[0] + } + return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1) +} + +type sshConnectionState struct { + hasActivePortForward bool + username string + remoteAddr string +} + +type authKey string + +func newAuthKey(username string, remoteAddr net.Addr) authKey { + return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String())) +} + +type Server struct { + sshServer *ssh.Server + mu sync.RWMutex + hostKeyPEM []byte + sessions map[SessionKey]ssh.Session + sessionCancels map[ConnectionKey]context.CancelFunc + sessionJWTUsers map[SessionKey]string + pendingAuthJWT map[authKey]string + + allowLocalPortForwarding bool + allowRemotePortForwarding bool + allowRootLogin bool + allowSFTP bool + jwtEnabled bool + + netstackNet *netstack.Net + + wgAddress wgaddr.Address + + remoteForwardListeners map[ForwardKey]net.Listener + sshConnections map[*cryptossh.ServerConn]*sshConnectionState + + jwtValidator *jwt.Validator + jwtExtractor *jwt.ClaimsExtractor + jwtConfig *JWTConfig + + suSupportsPty bool +} + +type JWTConfig struct { + Issuer string + Audience string + KeysLocation string + MaxTokenAge int64 +} + +// Config contains all SSH server configuration options +type Config struct { + // JWT authentication configuration. If nil, JWT authentication is disabled + JWT *JWTConfig + + // HostKey is the SSH server host key in PEM format + HostKeyPEM []byte +} + +// SessionInfo contains information about an active SSH session +type SessionInfo struct { + Username string + RemoteAddress string + Command string + JWTUsername string +} + +// New creates an SSH server instance with the provided host key and optional JWT configuration +// If jwtConfig is nil, JWT authentication is disabled +func New(config *Config) *Server { + s := &Server{ + mu: sync.RWMutex{}, + hostKeyPEM: config.HostKeyPEM, + sessions: make(map[SessionKey]ssh.Session), + sessionJWTUsers: make(map[SessionKey]string), + pendingAuthJWT: make(map[authKey]string), + remoteForwardListeners: make(map[ForwardKey]net.Listener), + sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), + jwtEnabled: config.JWT != nil, + jwtConfig: config.JWT, + } + + return s +} + +// Start runs the SSH server +func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.sshServer != nil { + return errors.New("SSH server is already running") + } + + s.suSupportsPty = s.detectSuPtySupport(ctx) + + ln, addrDesc, err := s.createListener(ctx, addr) + if err != nil { + return fmt.Errorf("create listener: %w", err) + } + + sshServer, err := s.createSSHServer(ln.Addr()) + if err != nil { + s.closeListener(ln) + return fmt.Errorf("create SSH server: %w", err) + } + + s.sshServer = sshServer + log.Infof("SSH server started on %s", addrDesc) + + go func() { + if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Errorf("SSH server error: %v", err) + } + }() + return nil +} + +func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) { + if s.netstackNet != nil { + ln, err := s.netstackNet.ListenTCPAddrPort(addr) + if err != nil { + return nil, "", fmt.Errorf("listen on netstack: %w", err) + } + return ln, fmt.Sprintf("netstack %s", addr), nil + } + + tcpAddr := net.TCPAddrFromAddrPort(addr) + lc := net.ListenConfig{} + ln, err := lc.Listen(ctx, "tcp", tcpAddr.String()) + if err != nil { + return nil, "", fmt.Errorf("listen: %w", err) + } + return ln, addr.String(), nil +} + +func (s *Server) closeListener(ln net.Listener) { + if ln == nil { + return + } + if err := ln.Close(); err != nil { + log.Debugf("listener close error: %v", err) + } +} + +// Stop closes the SSH server +func (s *Server) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.sshServer == nil { + return nil + } + + if err := s.sshServer.Close(); err != nil { + log.Debugf("close SSH server: %v", err) + } + + s.sshServer = nil + + maps.Clear(s.sessions) + maps.Clear(s.sessionJWTUsers) + maps.Clear(s.pendingAuthJWT) + maps.Clear(s.sshConnections) + + for _, cancelFunc := range s.sessionCancels { + cancelFunc() + } + maps.Clear(s.sessionCancels) + + for _, listener := range s.remoteForwardListeners { + if err := listener.Close(); err != nil { + log.Debugf("close remote forward listener: %v", err) + } + } + maps.Clear(s.remoteForwardListeners) + + return nil +} + +// GetStatus returns the current status of the SSH server and active sessions +func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { + s.mu.RLock() + defer s.mu.RUnlock() + + enabled = s.sshServer != nil + + for sessionKey, session := range s.sessions { + cmd := "" + if len(session.Command()) > 0 { + cmd = safeLogCommand(session.Command()) + } + + jwtUsername := s.sessionJWTUsers[sessionKey] + + sessions = append(sessions, SessionInfo{ + Username: session.User(), + RemoteAddress: session.RemoteAddr().String(), + Command: cmd, + JWTUsername: jwtUsername, + }) + } + + return enabled, sessions +} + +// SetNetstackNet sets the netstack network for userspace networking +func (s *Server) SetNetstackNet(net *netstack.Net) { + s.mu.Lock() + defer s.mu.Unlock() + s.netstackNet = net +} + +// SetNetworkValidation configures network-based connection filtering +func (s *Server) SetNetworkValidation(addr wgaddr.Address) { + s.mu.Lock() + defer s.mu.Unlock() + s.wgAddress = addr +} + +// ensureJWTValidator initializes the JWT validator and extractor if not already initialized +func (s *Server) ensureJWTValidator() error { + s.mu.RLock() + if s.jwtValidator != nil && s.jwtExtractor != nil { + s.mu.RUnlock() + return nil + } + config := s.jwtConfig + s.mu.RUnlock() + + if config == nil { + return fmt.Errorf("JWT config not set") + } + + log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience) + + validator := jwt.NewValidator( + config.Issuer, + []string{config.Audience}, + config.KeysLocation, + true, + ) + + extractor := jwt.NewClaimsExtractor( + jwt.WithAudience(config.Audience), + ) + + s.mu.Lock() + defer s.mu.Unlock() + + if s.jwtValidator != nil && s.jwtExtractor != nil { + return nil + } + + s.jwtValidator = validator + s.jwtExtractor = extractor + + log.Infof("JWT validator initialized successfully") + return nil +} + +func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) { + s.mu.RLock() + jwtValidator := s.jwtValidator + jwtConfig := s.jwtConfig + s.mu.RUnlock() + + if jwtValidator == nil { + return nil, fmt.Errorf("JWT validator not initialized") + } + + token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString) + if err != nil { + if jwtConfig != nil { + if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil { + return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w", + jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err) + } + } + return nil, fmt.Errorf("validate token: %w", err) + } + + if err := s.checkTokenAge(token, jwtConfig); err != nil { + return nil, err + } + + return token, nil +} + +func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error { + if jwtConfig == nil { + return nil + } + + maxTokenAge := jwtConfig.MaxTokenAge + if maxTokenAge <= 0 { + maxTokenAge = DefaultJWTMaxTokenAge + } + + claims, ok := token.Claims.(gojwt.MapClaims) + if !ok { + userID := extractUserID(token) + return fmt.Errorf("token has invalid claims format (user=%s)", userID) + } + + iat, ok := claims["iat"].(float64) + if !ok { + userID := extractUserID(token) + return fmt.Errorf("token missing iat claim (user=%s)", userID) + } + + issuedAt := time.Unix(int64(iat), 0) + tokenAge := time.Since(issuedAt) + maxAge := time.Duration(maxTokenAge) * time.Second + if tokenAge > maxAge { + userID := getUserIDFromClaims(claims) + return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge) + } + + return nil +} + +func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) { + s.mu.RLock() + jwtExtractor := s.jwtExtractor + s.mu.RUnlock() + + if jwtExtractor == nil { + userID := extractUserID(token) + return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID) + } + + userAuth, err := jwtExtractor.ToUserAuth(token) + if err != nil { + userID := extractUserID(token) + return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err) + } + + if !s.hasSSHAccess(&userAuth) { + return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId) + } + + return &userAuth, nil +} + +func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool { + return userAuth.UserId != "" +} + +func extractUserID(token *gojwt.Token) string { + if token == nil { + return "unknown" + } + claims, ok := token.Claims.(gojwt.MapClaims) + if !ok { + return "unknown" + } + return getUserIDFromClaims(claims) +} + +func getUserIDFromClaims(claims gojwt.MapClaims) string { + if sub, ok := claims["sub"].(string); ok && sub != "" { + return sub + } + if userID, ok := claims["user_id"].(string); ok && userID != "" { + return userID + } + if email, ok := claims["email"].(string); ok && email != "" { + return email + } + return "unknown" +} + +func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) { + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid token format") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("decode payload: %w", err) + } + + var claims map[string]interface{} + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, fmt.Errorf("parse claims: %w", err) + } + + return claims, nil +} + +func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { + if err := s.ensureJWTValidator(); err != nil { + log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + return false + } + + token, err := s.validateJWTToken(password) + if err != nil { + log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + return false + } + + userAuth, err := s.extractAndValidateUser(token) + if err != nil { + log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + return false + } + + key := newAuthKey(ctx.User(), ctx.RemoteAddr()) + s.mu.Lock() + s.pendingAuthJWT[key] = userAuth.UserId + s.mu.Unlock() + + log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr()) + return true +} + +func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) { + s.mu.Lock() + defer s.mu.Unlock() + + if state, exists := s.sshConnections[sshConn]; exists { + state.hasActivePortForward = true + } else { + s.sshConnections[sshConn] = &sshConnectionState{ + hasActivePortForward: true, + username: username, + remoteAddr: remoteAddr, + } + } +} + +func (s *Server) connectionCloseHandler(conn net.Conn, err error) { + // We can't extract the SSH connection from net.Conn directly + // Connection cleanup will happen during session cleanup or via timeout + log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err) +} + +func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { + if ctx == nil { + return "unknown" + } + + // Try to match by SSH connection + sshConn := ctx.Value(ssh.ContextKeyConn) + if sshConn == nil { + return "unknown" + } + + s.mu.RLock() + defer s.mu.RUnlock() + + // Look through sessions to find one with matching connection + for sessionKey, session := range s.sessions { + if session.Context().Value(ssh.ContextKeyConn) == sshConn { + return sessionKey + } + } + + // If no session found, this might be during early connection setup + // Return a temporary key that we'll fix up later + if ctx.User() != "" && ctx.RemoteAddr() != nil { + tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) + log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey) + return tempKey + } + + return "unknown" +} + +func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { + s.mu.RLock() + netbirdNetwork := s.wgAddress.Network + localIP := s.wgAddress.IP + s.mu.RUnlock() + + if !netbirdNetwork.IsValid() || !localIP.IsValid() { + return conn + } + + remoteAddr := conn.RemoteAddr() + tcpAddr, ok := remoteAddr.(*net.TCPAddr) + if !ok { + log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr) + return nil + } + + remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP) + if !ok { + log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP) + return nil + } + + // Block connections from our own IP (prevent local apps from connecting to ourselves) + if remoteIP == localIP { + log.Warnf("SSH connection rejected from own IP %s", remoteIP) + return nil + } + + if !netbirdNetwork.Contains(remoteIP) { + log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP) + return nil + } + + log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr) + return conn +} + +func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { + if err := enableUserSwitching(); err != nil { + log.Warnf("failed to enable user switching: %v", err) + } + + serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion()) + if s.jwtEnabled { + serverVersion += " " + detection.JWTRequiredMarker + } + + server := &ssh.Server{ + Addr: addr.String(), + Handler: s.sessionHandler, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": s.sftpSubsystemHandler, + }, + HostSigners: []ssh.Signer{}, + ChannelHandlers: map[string]ssh.ChannelHandler{ + "session": ssh.DefaultSessionHandler, + "direct-tcpip": s.directTCPIPHandler, + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": s.tcpipForwardHandler, + "cancel-tcpip-forward": s.cancelTcpipForwardHandler, + }, + ConnCallback: s.connectionValidator, + ConnectionFailedCallback: s.connectionCloseHandler, + Version: serverVersion, + } + + if s.jwtEnabled { + server.PasswordHandler = s.passwordHandler + } + + hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM) + if err := server.SetOption(hostKeyPEM); err != nil { + return nil, fmt.Errorf("set host key: %w", err) + } + + s.configurePortForwarding(server) + return server, nil +} + +func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) { + s.mu.Lock() + defer s.mu.Unlock() + s.remoteForwardListeners[key] = ln +} + +func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { + s.mu.Lock() + defer s.mu.Unlock() + + ln, exists := s.remoteForwardListeners[key] + if !exists { + return false + } + + delete(s.remoteForwardListeners, key) + if err := ln.Close(); err != nil { + log.Debugf("remote forward listener close error: %v", err) + } + + return true +} + +func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) { + var payload struct { + Host string + Port uint32 + OriginatorAddr string + OriginatorPort uint32 + } + + if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { + if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil { + log.Debugf("channel reject error: %v", err) + } + return + } + + s.mu.RLock() + allowLocal := s.allowLocalPortForwarding + s.mu.RUnlock() + + if !allowLocal { + log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port) + _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") + return + } + + // Check privilege requirements for the destination port + if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { + log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) + _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") + return + } + + log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) + + ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) +} diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go new file mode 100644 index 000000000..24e455025 --- /dev/null +++ b/client/ssh/server/server_config_test.go @@ -0,0 +1,394 @@ +package server + +import ( + "context" + "fmt" + "net" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/ssh" + sshclient "github.com/netbirdio/netbird/client/ssh/client" +) + +func TestServer_RootLoginRestriction(t *testing.T) { + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowRoot bool + username string + expectError bool + description string + }{ + { + name: "root login allowed", + allowRoot: true, + username: "root", + expectError: false, + description: "Root login should succeed when allowed", + }, + { + name: "root login denied", + allowRoot: false, + username: "root", + expectError: true, + description: "Root login should fail when disabled", + }, + { + name: "regular user login always allowed", + allowRoot: false, + username: "testuser", + expectError: false, + description: "Regular user login should work regardless of root setting", + }, + } + + // Add Windows Administrator tests if on Windows + if runtime.GOOS == "windows" { + tests = append(tests, []struct { + name string + allowRoot bool + username string + expectError bool + description string + }{ + { + name: "Administrator login allowed", + allowRoot: true, + username: "Administrator", + expectError: false, + description: "Administrator login should succeed when allowed", + }, + { + name: "Administrator login denied", + allowRoot: false, + username: "Administrator", + expectError: true, + description: "Administrator login should fail when disabled", + }, + { + name: "administrator login denied (lowercase)", + allowRoot: false, + username: "administrator", + expectError: true, + description: "administrator login should fail when disabled (case insensitive)", + }, + }...) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock privileged environment to test root access controls + // Set up mock users based on platform + mockUsers := map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + "testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"), + } + + // Add Windows-specific users for Administrator tests + if runtime.GOOS == "windows" { + mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator") + mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator") + } + + cleanup := setupTestDependencies( + createTestUser("root", "0", "0", "/root"), // Running as root + nil, + runtime.GOOS, + 0, // euid 0 (root) + mockUsers, + nil, + ) + defer cleanup() + + // Create server with specific configuration + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(tt.allowRoot) + + // Test the userNameLookup method directly + user, err := server.userNameLookup(tt.username) + + if tt.expectError { + assert.Error(t, err, tt.description) + if tt.username == "root" || strings.ToLower(tt.username) == "administrator" { + // Check for appropriate error message based on platform capabilities + errorMsg := err.Error() + // Either privileged user restriction OR user switching limitation + hasPrivilegedError := strings.Contains(errorMsg, "privileged user") + hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported") + assert.True(t, hasPrivilegedError || hasSwitchingError, + "Expected privileged user or user switching error, got: %s", errorMsg) + } + } else { + if tt.username == "root" || strings.ToLower(tt.username) == "administrator" { + // For privileged users, we expect either success or a different error + // (like user not found), but not the "login disabled" error + if err != nil { + assert.NotContains(t, err.Error(), "privileged user login is disabled") + } + } else { + // For regular users, lookup should generally succeed or fall back gracefully + // Note: may return current user as fallback + assert.NotNil(t, user) + } + } + }) + } +} + +func TestServer_PortForwardingRestriction(t *testing.T) { + // Test that the port forwarding callbacks properly respect configuration flags + // This is a unit test of the callback logic, not a full integration test + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowLocalForwarding bool + allowRemoteForwarding bool + description string + }{ + { + name: "all forwarding allowed", + allowLocalForwarding: true, + allowRemoteForwarding: true, + description: "Both local and remote forwarding should be allowed", + }, + { + name: "local forwarding disabled", + allowLocalForwarding: false, + allowRemoteForwarding: true, + description: "Local forwarding should be denied when disabled", + }, + { + name: "remote forwarding disabled", + allowLocalForwarding: true, + allowRemoteForwarding: false, + description: "Remote forwarding should be denied when disabled", + }, + { + name: "all forwarding disabled", + allowLocalForwarding: false, + allowRemoteForwarding: false, + description: "Both forwarding types should be denied when disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create server with specific configuration + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowLocalPortForwarding(tt.allowLocalForwarding) + server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding) + + // We need to access the internal configuration to simulate the callback tests + // Since the callbacks are created inside the Start method, we'll test the logic directly + + // Test the configuration values are set correctly + server.mu.RLock() + allowLocal := server.allowLocalPortForwarding + allowRemote := server.allowRemotePortForwarding + server.mu.RUnlock() + + assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly") + assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly") + + // Simulate the callback logic + localResult := allowLocal // This would be the callback return value + remoteResult := allowRemote // This would be the callback return value + + assert.Equal(t, tt.allowLocalForwarding, localResult, + "Local port forwarding callback should return correct value") + assert.Equal(t, tt.allowRemoteForwarding, remoteResult, + "Remote port forwarding callback should return correct value") + }) + } +} + +func TestServer_PortConflictHandling(t *testing.T) { + // Test that multiple sessions requesting the same local port are handled naturally by the OS + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Get a free port for testing + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + testPort := ln.Addr().(*net.TCPAddr).Port + err = ln.Close() + require.NoError(t, err) + + // Connect first client + ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel1() + + client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client1.Close() + assert.NoError(t, err) + }() + + // Connect second client + ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client2.Close() + assert.NoError(t, err) + }() + + // First client binds to the test port + localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort) + remoteAddr := "127.0.0.1:80" + + // Start first client's port forwarding + done1 := make(chan error, 1) + go func() { + // This should succeed and hold the port + err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr) + done1 <- err + }() + + // Give first client time to bind + time.Sleep(200 * time.Millisecond) + + // Second client tries to bind to same port + localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort) + + shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer shortCancel() + + err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr) + // Second client should fail due to "address already in use" + assert.Error(t, err, "Second client should fail to bind to same port") + if err != nil { + // The error should indicate the address is already in use + errMsg := strings.ToLower(err.Error()) + if runtime.GOOS == "windows" { + assert.Contains(t, errMsg, "only one usage of each socket address", + "Error should indicate port conflict") + } else { + assert.Contains(t, errMsg, "address already in use", + "Error should indicate port conflict") + } + } + + // Cancel first client's context and wait for it to finish + cancel1() + select { + case err1 := <-done1: + // Should get context cancelled or deadline exceeded + assert.Error(t, err1, "First client should exit when context cancelled") + case <-time.After(2 * time.Second): + t.Error("First client did not exit within timeout") + } +} + +func TestServer_IsPrivilegedUser(t *testing.T) { + + tests := []struct { + username string + expected bool + description string + }{ + { + username: "root", + expected: true, + description: "root should be considered privileged", + }, + { + username: "regular", + expected: false, + description: "regular user should not be privileged", + }, + { + username: "", + expected: false, + description: "empty username should not be privileged", + }, + } + + // Add Windows-specific tests + if runtime.GOOS == "windows" { + tests = append(tests, []struct { + username string + expected bool + description string + }{ + { + username: "Administrator", + expected: true, + description: "Administrator should be considered privileged on Windows", + }, + { + username: "administrator", + expected: true, + description: "administrator should be considered privileged on Windows (case insensitive)", + }, + }...) + } else { + // On non-Windows systems, Administrator should not be privileged + tests = append(tests, []struct { + username string + expected bool + description string + }{ + { + username: "Administrator", + expected: false, + description: "Administrator should not be privileged on non-Windows systems", + }, + }...) + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + result := isPrivilegedUsername(tt.username) + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} diff --git a/client/ssh/server/server_test.go b/client/ssh/server/server_test.go new file mode 100644 index 000000000..661068539 --- /dev/null +++ b/client/ssh/server/server_test.go @@ -0,0 +1,441 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "os/user" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" +) + +func TestServer_StartStop(t *testing.T) { + key, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: key, + JWT: nil, + } + server := New(serverConfig) + + err = server.Stop() + assert.NoError(t, err) +} + +func TestSSHServerIntegration(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server with random port + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server in background + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + // Get a free port + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key for verification + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + // Create SSH client config + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // Connect to SSH server + client, err := cryptossh.Dial("tcp", serverAddr, config) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("close client: %v", err) + } + }() + + // Test creating a session + session, err := client.NewSession() + require.NoError(t, err) + defer func() { + if err := session.Close(); err != nil { + t.Logf("close session: %v", err) + } + }() + + // Note: Since we don't have a real shell environment in tests, + // we can't test actual command execution, but we can verify + // the connection and authentication work + t.Log("SSH connection and authentication successful") +} + +func TestSSHServerMultipleConnections(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // Test multiple concurrent connections + const numConnections = 5 + results := make(chan error, numConnections) + + for i := 0; i < numConnections; i++ { + go func(id int) { + client, err := cryptossh.Dial("tcp", serverAddr, config) + if err != nil { + results <- fmt.Errorf("connection %d failed: %w", id, err) + return + } + defer func() { + _ = client.Close() // Ignore error in test goroutine + }() + + session, err := client.NewSession() + if err != nil { + results <- fmt.Errorf("session %d failed: %w", id, err) + return + } + defer func() { + _ = session.Close() // Ignore error in test goroutine + }() + + results <- nil + }(i) + } + + // Wait for all connections to complete + for i := 0; i < numConnections; i++ { + select { + case err := <-results: + assert.NoError(t, err) + case <-time.After(10 * time.Second): + t.Fatalf("Connection %d timed out", i) + } + } +} + +func TestSSHServerNoAuthMode(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Generate a client private key for SSH protocol (server doesn't check it) + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + // Try to connect with client key + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(clientSigner), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // This should succeed in no-auth mode (server doesn't verify keys) + conn, err := cryptossh.Dial("tcp", serverAddr, config) + assert.NoError(t, err, "Connection should succeed in no-auth mode") + if conn != nil { + assert.NoError(t, conn.Close()) + } +} + +func TestSSHServerStartStopCycle(t *testing.T) { + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + serverAddr := "127.0.0.1:0" + + // Test multiple start/stop cycles + for i := 0; i < 3; i++ { + t.Logf("Start/stop cycle %d", i+1) + + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case <-started: + case err := <-errChan: + t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err) + case <-time.After(5 * time.Second): + t.Fatalf("Cycle %d: Server start timeout", i+1) + } + + err = server.Stop() + require.NoError(t, err, "Cycle %d: Stop should succeed", i+1) + } +} + +func TestSSHServer_WindowsShellHandling(t *testing.T) { + if testing.Short() { + t.Skip("Skipping Windows shell test in short mode") + } + + server := &Server{} + + if runtime.GOOS == "windows" { + // Test Windows cmd.exe shell behavior + args := server.getShellCommandArgs("cmd.exe", "echo test") + assert.Equal(t, "cmd.exe", args[0]) + assert.Equal(t, "-Command", args[1]) + assert.Equal(t, "echo test", args[2]) + + // Test PowerShell behavior + args = server.getShellCommandArgs("powershell.exe", "echo test") + assert.Equal(t, "powershell.exe", args[0]) + assert.Equal(t, "-Command", args[1]) + assert.Equal(t, "echo test", args[2]) + } else { + // Test Unix shell behavior + args := server.getShellCommandArgs("/bin/sh", "echo test") + assert.Equal(t, "/bin/sh", args[0]) + assert.Equal(t, "-l", args[1]) + assert.Equal(t, "-c", args[2]) + assert.Equal(t, "echo test", args[3]) + } +} + +func TestSSHServer_PortForwardingConfiguration(t *testing.T) { + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig1 := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server1 := New(serverConfig1) + + serverConfig2 := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server2 := New(serverConfig2) + + assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security") + assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security") + + server2.SetAllowLocalPortForwarding(true) + server2.SetAllowRemotePortForwarding(true) + + assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set") + assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set") +} diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go new file mode 100644 index 000000000..4e6d72098 --- /dev/null +++ b/client/ssh/server/session_handlers.go @@ -0,0 +1,168 @@ +package server + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" +) + +// sessionHandler handles SSH sessions +func (s *Server) sessionHandler(session ssh.Session) { + sessionKey := s.registerSession(session) + + key := newAuthKey(session.User(), session.RemoteAddr()) + s.mu.Lock() + jwtUsername := s.pendingAuthJWT[key] + if jwtUsername != "" { + s.sessionJWTUsers[sessionKey] = jwtUsername + delete(s.pendingAuthJWT, key) + } + s.mu.Unlock() + + logger := log.WithField("session", sessionKey) + if jwtUsername != "" { + logger = logger.WithField("jwt_user", jwtUsername) + logger.Infof("SSH session started (JWT user: %s)", jwtUsername) + } else { + logger.Infof("SSH session started") + } + sessionStart := time.Now() + + defer s.unregisterSession(sessionKey, session) + defer func() { + duration := time.Since(sessionStart).Round(time.Millisecond) + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { + logger.Warnf("close session after %v: %v", duration, err) + } + logger.Infof("SSH session closed after %v", duration) + }() + + privilegeResult, err := s.userPrivilegeCheck(session.User()) + if err != nil { + s.handlePrivError(logger, session, err) + return + } + + ptyReq, winCh, isPty := session.Pty() + hasCommand := len(session.Command()) > 0 + + switch { + case isPty && hasCommand: + // ssh -t - Pty command execution + s.handleCommand(logger, session, privilegeResult, winCh) + case isPty: + // ssh - Pty interactive session (login) + s.handlePty(logger, session, privilegeResult, ptyReq, winCh) + case hasCommand: + // ssh - non-Pty command execution + s.handleCommand(logger, session, privilegeResult, nil) + default: + s.rejectInvalidSession(logger, session) + } +} + +func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { + if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) +} + +func (s *Server) registerSession(session ssh.Session) SessionKey { + sessionID := session.Context().Value(ssh.ContextKeySessionID) + if sessionID == nil { + sessionID = fmt.Sprintf("%p", session) + } + + // Create a short 4-byte identifier from the full session ID + hasher := sha256.New() + hasher.Write([]byte(fmt.Sprintf("%v", sessionID))) + hash := hasher.Sum(nil) + shortID := hex.EncodeToString(hash[:4]) + + remoteAddr := session.RemoteAddr().String() + username := session.User() + sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) + + s.mu.Lock() + s.sessions[sessionKey] = session + s.mu.Unlock() + + return sessionKey +} + +func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) { + s.mu.Lock() + delete(s.sessions, sessionKey) + delete(s.sessionJWTUsers, sessionKey) + + // Cancel all port forwarding connections for this session + var connectionsToCancel []ConnectionKey + for key := range s.sessionCancels { + if strings.HasPrefix(string(key), string(sessionKey)+"-") { + connectionsToCancel = append(connectionsToCancel, key) + } + } + + for _, key := range connectionsToCancel { + if cancelFunc, exists := s.sessionCancels[key]; exists { + log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key) + cancelFunc() + delete(s.sessionCancels, key) + } + } + + if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil { + if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok { + delete(s.sshConnections, sshConn) + } + } + + s.mu.Unlock() +} + +func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) { + logger.Warnf("user privilege check failed: %v", err) + + errorMsg := s.buildUserLookupErrorMessage(err) + + if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if exitErr := session.Exit(1); exitErr != nil { + logSessionExitError(logger, exitErr) + } +} + +// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type +func (s *Server) buildUserLookupErrorMessage(err error) string { + var privilegedErr *PrivilegedUserError + + switch { + case errors.As(err, &privilegedErr): + if privilegedErr.Username == "root" { + return "root login is disabled on this SSH server\n" + } + return "privileged user access is disabled on this SSH server\n" + + case errors.Is(err, ErrPrivilegeRequired): + return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n" + + case errors.Is(err, ErrPrivilegedUserSwitch): + return "Cannot switch to privileged user - current user lacks required privileges\n" + + default: + return "User authentication failed\n" + } +} diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go new file mode 100644 index 000000000..c35e4da0b --- /dev/null +++ b/client/ssh/server/session_handlers_js.go @@ -0,0 +1,22 @@ +//go:build js + +package server + +import ( + "fmt" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// handlePty is not supported on JS/WASM +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { + errorMsg := "PTY sessions are not supported on WASM/JS platform\n" + if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false +} diff --git a/client/ssh/server/sftp.go b/client/ssh/server/sftp.go new file mode 100644 index 000000000..c2b9f552b --- /dev/null +++ b/client/ssh/server/sftp.go @@ -0,0 +1,81 @@ +package server + +import ( + "fmt" + "io" + + "github.com/gliderlabs/ssh" + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" +) + +// SetAllowSFTP enables or disables SFTP support +func (s *Server) SetAllowSFTP(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowSFTP = allow +} + +// sftpSubsystemHandler handles SFTP subsystem requests +func (s *Server) sftpSubsystemHandler(sess ssh.Session) { + s.mu.RLock() + allowSFTP := s.allowSFTP + s.mu.RUnlock() + + if !allowSFTP { + log.Debugf("SFTP subsystem request denied: SFTP disabled") + if err := sess.Exit(1); err != nil { + log.Debugf("SFTP session exit failed: %v", err) + } + return + } + + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: sess.User(), + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSFTP, + }) + + if !result.Allowed { + log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error) + if err := sess.Exit(1); err != nil { + log.Debugf("exit SFTP session: %v", err) + } + return + } + + log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username) + + if !result.RequiresUserSwitching { + if err := s.executeSftpDirect(sess); err != nil { + log.Errorf("SFTP direct execution: %v", err) + } + return + } + + if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil { + log.Errorf("SFTP privilege drop execution: %v", err) + } +} + +// executeSftpDirect executes SFTP directly without privilege dropping +func (s *Server) executeSftpDirect(sess ssh.Session) error { + log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User()) + + sftpServer, err := sftp.NewServer(sess) + if err != nil { + return fmt.Errorf("SFTP server creation: %w", err) + } + + defer func() { + if err := sftpServer.Close(); err != nil { + log.Debugf("failed to close sftp server: %v", err) + } + }() + + if err := sftpServer.Serve(); err != nil && err != io.EOF { + return fmt.Errorf("serve: %w", err) + } + + return nil +} diff --git a/client/ssh/server/sftp_js.go b/client/ssh/server/sftp_js.go new file mode 100644 index 000000000..3b27aeff4 --- /dev/null +++ b/client/ssh/server/sftp_js.go @@ -0,0 +1,12 @@ +//go:build js + +package server + +import ( + "os/user" +) + +// parseUserCredentials is not supported on JS/WASM +func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) { + return 0, 0, nil, errNotSupported +} diff --git a/client/ssh/server/sftp_test.go b/client/ssh/server/sftp_test.go new file mode 100644 index 000000000..32a3643e4 --- /dev/null +++ b/client/ssh/server/sftp_test.go @@ -0,0 +1,228 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "os" + "os/user" + "testing" + "time" + + "github.com/pkg/sftp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/ssh" +) + +func TestSSHServer_SFTPSubsystem(t *testing.T) { + // Skip SFTP test when running as root due to protocol issues in some environments + if os.Geteuid() == 0 { + t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues") + } + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server with SFTP enabled + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowSFTP(true) + server.SetAllowRootLogin(true) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // (currentUser already obtained at function start) + + // Create SSH client connection + clientConfig := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 5 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig) + require.NoError(t, err, "SSH connection should succeed") + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + // Create SFTP client + sftpClient, err := sftp.NewClient(conn) + require.NoError(t, err, "SFTP client creation should succeed") + defer func() { + if err := sftpClient.Close(); err != nil { + t.Logf("SFTP client close error: %v", err) + } + }() + + // Test basic SFTP operations + workingDir, err := sftpClient.Getwd() + assert.NoError(t, err, "Should be able to get working directory") + assert.NotEmpty(t, workingDir, "Working directory should not be empty") + + // Test directory listing + files, err := sftpClient.ReadDir(".") + assert.NoError(t, err, "Should be able to list current directory") + assert.NotNil(t, files, "File list should not be nil") +} + +func TestSSHServer_SFTPDisabled(t *testing.T) { + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server with SFTP disabled + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowSFTP(false) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // (currentUser already obtained at function start) + + // Create SSH client connection + clientConfig := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 5 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig) + require.NoError(t, err, "SSH connection should succeed") + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + // Try to create SFTP client - should fail when SFTP is disabled + _, err = sftp.NewClient(conn) + assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled") +} diff --git a/client/ssh/server/sftp_unix.go b/client/ssh/server/sftp_unix.go new file mode 100644 index 000000000..44202bead --- /dev/null +++ b/client/ssh/server/sftp_unix.go @@ -0,0 +1,71 @@ +//go:build !windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/user" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping +func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error { + uid, gid, groups, err := s.parseUserCredentials(targetUser) + if err != nil { + return fmt.Errorf("parse user credentials: %w", err) + } + + sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir) + if err != nil { + return fmt.Errorf("create executor: %w", err) + } + + sftpCmd.Stdin = sess + sftpCmd.Stdout = sess + sftpCmd.Stderr = sess.Stderr() + + log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid) + + if err := sftpCmd.Start(); err != nil { + return fmt.Errorf("starting SFTP executor: %w", err) + } + + if err := sftpCmd.Wait(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + log.Tracef("SFTP process exited with code %d", exitError.ExitCode()) + return nil + } + return fmt.Errorf("exec: %w", err) + } + + return nil +} + +// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping +func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) { + netbirdPath, err := os.Executable() + if err != nil { + return nil, err + } + + args := []string{ + "ssh", "sftp", + "--uid", strconv.FormatUint(uint64(uid), 10), + "--gid", strconv.FormatUint(uint64(gid), 10), + "--working-dir", workingDir, + } + + for _, group := range groups { + args = append(args, "--groups", strconv.FormatUint(uint64(group), 10)) + } + + log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args) + return exec.CommandContext(sess.Context(), netbirdPath, args...), nil +} diff --git a/client/ssh/server/sftp_windows.go b/client/ssh/server/sftp_windows.go new file mode 100644 index 000000000..dc532b9e7 --- /dev/null +++ b/client/ssh/server/sftp_windows.go @@ -0,0 +1,91 @@ +//go:build windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/user" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +// createSftpCommand creates a Windows SFTP command with user switching. +// The caller must close the returned token handle after starting the process. +func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) { + username, domain := s.parseUsername(targetUser.Username) + + netbirdPath, err := os.Executable() + if err != nil { + return nil, 0, fmt.Errorf("get netbird executable path: %w", err) + } + + args := []string{ + "ssh", "sftp", + "--working-dir", targetUser.HomeDir, + "--windows-username", username, + "--windows-domain", domain, + } + + pd := NewPrivilegeDropper() + token, err := pd.createToken(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("create token: %w", err) + } + + defer func() { + if err := windows.CloseHandle(token); err != nil { + log.Warnf("failed to close impersonation token: %v", err) + } + }() + + cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir) + if err != nil { + return nil, 0, fmt.Errorf("create SFTP command: %w", err) + } + + log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username) + return cmd, primaryToken, nil +} + +// executeSftpCommand executes a Windows SFTP command with proper I/O handling +func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error { + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("close primary token: %v", err) + } + }() + + sftpCmd.Stdin = sess + sftpCmd.Stdout = sess + sftpCmd.Stderr = sess.Stderr() + + if err := sftpCmd.Start(); err != nil { + return fmt.Errorf("starting sftp executor: %w", err) + } + + if err := sftpCmd.Wait(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + log.Tracef("sftp process exited with code %d", exitError.ExitCode()) + return nil + } + + return fmt.Errorf("exec sftp: %w", err) + } + + return nil +} + +// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping +func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error { + sftpCmd, token, err := s.createSftpCommand(targetUser, sess) + if err != nil { + return fmt.Errorf("create sftp: %w", err) + } + return s.executeSftpCommand(sess, sftpCmd, token) +} diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go new file mode 100644 index 000000000..fea9d2910 --- /dev/null +++ b/client/ssh/server/shell.go @@ -0,0 +1,180 @@ +package server + +import ( + "bufio" + "fmt" + "net" + "os" + "os/exec" + "os/user" + "runtime" + "strconv" + "strings" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +const ( + defaultUnixShell = "/bin/sh" + + pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name + powershellExe = "powershell.exe" +) + +// getUserShell returns the appropriate shell for the given user ID +// Handles all platform-specific logic and fallbacks consistently +func getUserShell(userID string) string { + switch runtime.GOOS { + case "windows": + return getWindowsUserShell() + default: + return getUnixUserShell(userID) + } +} + +// getWindowsUserShell returns the best shell for Windows users. +// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection +// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters. +// PowerShell provides safer argument handling and is available on all modern Windows systems. +// Order: pwsh.exe -> powershell.exe +func getWindowsUserShell() string { + if path, err := exec.LookPath(pwshExe); err == nil { + return path + } + if path, err := exec.LookPath(powershellExe); err == nil { + return path + } + + return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` +} + +// getUnixUserShell returns the shell for Unix-like systems +func getUnixUserShell(userID string) string { + shell := getShellFromPasswd(userID) + if shell != "" { + return shell + } + + if shell := os.Getenv("SHELL"); shell != "" { + return shell + } + + return defaultUnixShell +} + +// getShellFromPasswd reads the shell from /etc/passwd for the given user ID +func getShellFromPasswd(userID string) string { + file, err := os.Open("/etc/passwd") + if err != nil { + return "" + } + defer func() { + if err := file.Close(); err != nil { + log.Warnf("close /etc/passwd file: %v", err) + } + }() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Split(line, ":") + if len(fields) < 7 { + continue + } + + // field 2 is UID + if fields[2] == userID { + shell := strings.TrimSpace(fields[6]) + return shell + } + } + + if err := scanner.Err(); err != nil { + log.Warnf("error reading /etc/passwd: %v", err) + } + + return "" +} + +// prepareUserEnv prepares environment variables for user execution +func prepareUserEnv(user *user.User, shell string) []string { + pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games" + if runtime.GOOS == "windows" { + pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0` + } + + return []string{ + fmt.Sprint("SHELL=" + shell), + fmt.Sprint("USER=" + user.Username), + fmt.Sprint("LOGNAME=" + user.Username), + fmt.Sprint("HOME=" + user.HomeDir), + "PATH=" + pathValue, + } +} + +// acceptEnv checks if environment variable from SSH client should be accepted +// This is a whitelist of variables that SSH clients can send to the server +func acceptEnv(envVar string) bool { + varName := envVar + if idx := strings.Index(envVar, "="); idx != -1 { + varName = envVar[:idx] + } + + exactMatches := []string{ + "LANG", + "LANGUAGE", + "TERM", + "COLORTERM", + "EDITOR", + "VISUAL", + "PAGER", + "LESS", + "LESSCHARSET", + "TZ", + } + + prefixMatches := []string{ + "LC_", + } + + for _, exact := range exactMatches { + if varName == exact { + return true + } + } + + for _, prefix := range prefixMatches { + if strings.HasPrefix(varName, prefix) { + return true + } + } + + return false +} + +// prepareSSHEnv prepares SSH protocol-specific environment variables +// These variables provide information about the SSH connection itself +func prepareSSHEnv(session ssh.Session) []string { + remoteAddr := session.RemoteAddr() + localAddr := session.LocalAddr() + + remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String()) + if err != nil { + remoteHost = remoteAddr.String() + remotePort = "0" + } + + localHost, localPort, err := net.SplitHostPort(localAddr.String()) + if err != nil { + localHost = localAddr.String() + localPort = strconv.Itoa(InternalSSHPort) + } + + return []string{ + // SSH_CLIENT format: "client_ip client_port server_port" + fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort), + // SSH_CONNECTION format: "client_ip client_port server_ip server_port" + fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort), + } +} diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go new file mode 100644 index 000000000..20930c721 --- /dev/null +++ b/client/ssh/server/test.go @@ -0,0 +1,45 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" +) + +func StartTestServer(t *testing.T, server *Server) string { + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort := netip.MustParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + return actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + return "" +} diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go new file mode 100644 index 000000000..799882cbb --- /dev/null +++ b/client/ssh/server/user_utils.go @@ -0,0 +1,411 @@ +package server + +import ( + "errors" + "fmt" + "os" + "os/user" + "runtime" + "strings" + + log "github.com/sirupsen/logrus" +) + +var ( + ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges") + ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges") +) + +// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.) +func isPlatformUnix() bool { + return getCurrentOS() != "windows" +} + +// Dependency injection variables for testing - allows mocking dynamic runtime checks +var ( + getCurrentUser = user.Current + lookupUser = user.Lookup + getCurrentOS = func() string { return runtime.GOOS } + getIsProcessPrivileged = isCurrentProcessPrivileged + + getEuid = os.Geteuid +) + +const ( + // FeatureSSHLogin represents SSH login operations for privilege checking + FeatureSSHLogin = "SSH login" + // FeatureSFTP represents SFTP operations for privilege checking + FeatureSFTP = "SFTP" +) + +// PrivilegeCheckRequest represents a privilege check request +type PrivilegeCheckRequest struct { + // Username being requested (empty = current user) + RequestedUsername string + FeatureSupportsUserSwitch bool // Does this feature/operation support user switching? + FeatureName string +} + +// PrivilegeCheckResult represents the result of a privilege check +type PrivilegeCheckResult struct { + // Allowed indicates whether the privilege check passed + Allowed bool + // User is the effective user to use for the operation (nil if not allowed) + User *user.User + // Error contains the reason for denial (nil if allowed) + Error error + // UsedFallback indicates we fell back to current user instead of requested user. + // This happens on Unix when running as an unprivileged user (e.g., in containers) + // where there's no point in user switching since we lack privileges anyway. + // When true, all privilege checks have already been performed and no additional + // privilege dropping or root checks are needed - the current user is the target. + UsedFallback bool + // RequiresUserSwitching indicates whether user switching will actually occur + // (false for fallback cases where no actual switching happens) + RequiresUserSwitching bool +} + +// CheckPrivileges performs comprehensive privilege checking for all SSH features. +// This is the single source of truth for privilege decisions across the SSH server. +func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult { + context, err := s.buildPrivilegeCheckContext(req.FeatureName) + if err != nil { + return PrivilegeCheckResult{Allowed: false, Error: err} + } + + // Handle empty username case - but still check root access controls + if req.RequestedUsername == "" { + if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot { + return PrivilegeCheckResult{ + Allowed: false, + Error: &PrivilegedUserError{Username: context.currentUser.Username}, + } + } + return PrivilegeCheckResult{ + Allowed: true, + User: context.currentUser, + RequiresUserSwitching: false, + } + } + + return s.checkUserRequest(context, req) +} + +// buildPrivilegeCheckContext gathers all the context needed for privilege checking +func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) { + currentUser, err := getCurrentUser() + if err != nil { + return nil, fmt.Errorf("get current user for %s: %w", featureName, err) + } + + s.mu.RLock() + allowRoot := s.allowRootLogin + s.mu.RUnlock() + + return &privilegeCheckContext{ + currentUser: currentUser, + currentUserPrivileged: getIsProcessPrivileged(), + allowRoot: allowRoot, + }, nil +} + +// checkUserRequest handles normal privilege checking flow for specific usernames +func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult { + if !ctx.currentUserPrivileged && isPlatformUnix() { + log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)", + ctx.currentUser.Username, req.FeatureName, req.RequestedUsername) + return PrivilegeCheckResult{ + Allowed: true, + User: ctx.currentUser, + UsedFallback: true, + RequiresUserSwitching: false, + } + } + + resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername) + if err != nil { + // Calculate if user switching would be required even if lookup failed + needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username) + return PrivilegeCheckResult{ + Allowed: false, + Error: err, + RequiresUserSwitching: needsUserSwitching, + } + } + + needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser) + + if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot { + return PrivilegeCheckResult{ + Allowed: false, + Error: &PrivilegedUserError{Username: resolvedUser.Username}, + RequiresUserSwitching: needsUserSwitching, + } + } + + if needsUserSwitching && !req.FeatureSupportsUserSwitch { + return PrivilegeCheckResult{ + Allowed: false, + Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName), + RequiresUserSwitching: needsUserSwitching, + } + } + + return PrivilegeCheckResult{ + Allowed: true, + User: resolvedUser, + RequiresUserSwitching: needsUserSwitching, + } +} + +// resolveRequestedUser resolves a username to its canonical user identity +func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) { + if requestedUsername == "" { + return getCurrentUser() + } + + if err := validateUsername(requestedUsername); err != nil { + return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err) + } + + u, err := lookupUser(requestedUsername) + if err != nil { + return nil, &UserNotFoundError{Username: requestedUsername, Cause: err} + } + return u, nil +} + +// isSameResolvedUser compares two resolved user identities +func isSameResolvedUser(user1, user2 *user.User) bool { + if user1 == nil || user2 == nil { + return user1 == user2 + } + return user1.Uid == user2.Uid +} + +// privilegeCheckContext holds all context needed for privilege checking +type privilegeCheckContext struct { + currentUser *user.User + currentUserPrivileged bool + allowRoot bool +} + +// isSameUser checks if two usernames refer to the same user +// SECURITY: This function must be conservative - it should only return true +// when we're certain both usernames refer to the exact same user identity +func isSameUser(requestedUsername, currentUsername string) bool { + // Empty requested username means current user + if requestedUsername == "" { + return true + } + + // Exact match (most common case) + if getCurrentOS() == "windows" { + if strings.EqualFold(requestedUsername, currentUsername) { + return true + } + } else { + if requestedUsername == currentUsername { + return true + } + } + + // Windows domain resolution: only allow domain stripping when comparing + // a bare username against the current user's domain-qualified name + if getCurrentOS() == "windows" { + return isWindowsSameUser(requestedUsername, currentUsername) + } + + return false +} + +// isWindowsSameUser handles Windows-specific user comparison with domain logic +func isWindowsSameUser(requestedUsername, currentUsername string) bool { + // Extract domain and username parts + extractParts := func(name string) (domain, user string) { + // Handle DOMAIN\username format + if idx := strings.LastIndex(name, `\`); idx != -1 { + return name[:idx], name[idx+1:] + } + // Handle user@domain.com format + if idx := strings.Index(name, "@"); idx != -1 { + return name[idx+1:], name[:idx] + } + // No domain specified - local machine + return "", name + } + + reqDomain, reqUser := extractParts(requestedUsername) + curDomain, curUser := extractParts(currentUsername) + + // Case-insensitive username comparison + if !strings.EqualFold(reqUser, curUser) { + return false + } + + // If requested username has no domain, it refers to local machine user + // Allow this to match the current user regardless of current user's domain + if reqDomain == "" { + return true + } + + // If both have domains, they must match exactly (case-insensitive) + return strings.EqualFold(reqDomain, curDomain) +} + +// SetAllowRootLogin configures root login access +func (s *Server) SetAllowRootLogin(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowRootLogin = allow +} + +// userNameLookup performs user lookup with root login permission check +func (s *Server) userNameLookup(username string) (*user.User, error) { + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSSHLogin, + }) + + if !result.Allowed { + return nil, result.Error + } + + return result.User, nil +} + +// userPrivilegeCheck performs user lookup with full privilege check result +func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) { + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSSHLogin, + }) + + if !result.Allowed { + return result, result.Error + } + + return result, nil +} + +// isPrivilegedUsername checks if the given username represents a privileged user across platforms. +// On Unix: root +// On Windows: Administrator, SYSTEM (case-insensitive) +// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com" +func isPrivilegedUsername(username string) bool { + if getCurrentOS() != "windows" { + return username == "root" + } + + bareUsername := username + // Handle Windows domain format: DOMAIN\username + if idx := strings.LastIndex(username, `\`); idx != -1 { + bareUsername = username[idx+1:] + } + // Handle email-style format: username@domain.com + if idx := strings.Index(bareUsername, "@"); idx != -1 { + bareUsername = bareUsername[:idx] + } + + return isWindowsPrivilegedUser(bareUsername) +} + +// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account +func isWindowsPrivilegedUser(bareUsername string) bool { + // common privileged usernames (case insensitive) + privilegedNames := []string{ + "administrator", + "admin", + "root", + "system", + "localsystem", + "networkservice", + "localservice", + } + + usernameLower := strings.ToLower(bareUsername) + for _, privilegedName := range privilegedNames { + if usernameLower == privilegedName { + return true + } + } + + // computer accounts (ending with $) are not privileged by themselves + // They only gain privileges through group membership or specific SIDs + + if targetUser, err := lookupUser(bareUsername); err == nil { + return isWindowsPrivilegedSID(targetUser.Uid) + } + + return false +} + +// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account +func isWindowsPrivilegedSID(sid string) bool { + privilegedSIDs := []string{ + "S-1-5-18", // Local System (SYSTEM) + "S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE) + "S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE) + "S-1-5-32-544", // Administrators group (BUILTIN\Administrators) + "S-1-5-500", // Built-in Administrator account (local machine RID 500) + } + + for _, privilegedSID := range privilegedSIDs { + if sid == privilegedSID { + return true + } + } + + // Check for domain administrator accounts (RID 500 in any domain) + // Format: S-1-5-21-domain-domain-domain-500 + // This is reliable as RID 500 is reserved for the domain Administrator account + if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") { + return true + } + + // Check for other well-known privileged RIDs in domain contexts + // RID 512 = Domain Admins group, RID 516 = Domain Controllers group + if strings.HasPrefix(sid, "S-1-5-21-") { + if strings.HasSuffix(sid, "-512") || // Domain Admins group + strings.HasSuffix(sid, "-516") || // Domain Controllers group + strings.HasSuffix(sid, "-519") { // Enterprise Admins group + return true + } + } + + return false +} + +// isCurrentProcessPrivileged checks if the current process is running with elevated privileges. +// On Unix systems, this means running as root (UID 0). +// On Windows, this means running as Administrator or SYSTEM. +func isCurrentProcessPrivileged() bool { + if getCurrentOS() == "windows" { + return isWindowsElevated() + } + return getEuid() == 0 +} + +// isWindowsElevated checks if the current process is running with elevated privileges on Windows +func isWindowsElevated() bool { + currentUser, err := getCurrentUser() + if err != nil { + log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err) + return false + } + + if isWindowsPrivilegedSID(currentUser.Uid) { + log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid) + return true + } + + if isPrivilegedUsername(currentUser.Username) { + log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username) + return true + } + + log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid) + return false +} diff --git a/client/ssh/server/user_utils_js.go b/client/ssh/server/user_utils_js.go new file mode 100644 index 000000000..163b24c6c --- /dev/null +++ b/client/ssh/server/user_utils_js.go @@ -0,0 +1,8 @@ +//go:build js + +package server + +// validateUsername is not supported on JS/WASM +func validateUsername(_ string) error { + return errNotSupported +} diff --git a/client/ssh/server/user_utils_test.go b/client/ssh/server/user_utils_test.go new file mode 100644 index 000000000..637dc10d0 --- /dev/null +++ b/client/ssh/server/user_utils_test.go @@ -0,0 +1,908 @@ +package server + +import ( + "errors" + "os/user" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Test helper functions +func createTestUser(username, uid, gid, homeDir string) *user.User { + return &user.User{ + Uid: uid, + Gid: gid, + Username: username, + Name: username, + HomeDir: homeDir, + } +} + +// Test dependency injection setup - injects platform dependencies to test real logic +func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() { + // Store originals + originalGetCurrentUser := getCurrentUser + originalLookupUser := lookupUser + originalGetCurrentOS := getCurrentOS + originalGetEuid := getEuid + + // Reset caches to ensure clean test state + + // Set test values - inject platform dependencies + getCurrentUser = func() (*user.User, error) { + return currentUser, currentUserErr + } + + lookupUser = func(username string) (*user.User, error) { + if err, exists := lookupErrors[username]; exists { + return nil, err + } + if userObj, exists := lookupUsers[username]; exists { + return userObj, nil + } + return nil, errors.New("user: unknown user " + username) + } + + getCurrentOS = func() string { + return os + } + + getEuid = func() int { + return euid + } + + // Mock privilege detection based on the test user + getIsProcessPrivileged = func() bool { + if currentUser == nil { + return false + } + // Check both username and SID for Windows systems + if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) { + return true + } + return isPrivilegedUsername(currentUser.Username) + } + + // Return cleanup function + return func() { + getCurrentUser = originalGetCurrentUser + lookupUser = originalLookupUser + getCurrentOS = originalGetCurrentOS + getEuid = originalGetEuid + + getIsProcessPrivileged = isCurrentProcessPrivileged + + // Reset caches after test + } +} + +func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) { + tests := []struct { + name string + os string + euid int + currentUser *user.User + requestedUsername string + featureSupportsUserSwitch bool + allowRoot bool + lookupUsers map[string]*user.User + expectedAllowed bool + expectedRequiresSwitch bool + }{ + { + name: "linux_root_can_switch_to_alice", + os: "linux", + euid: 0, // Root process + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "alice", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "1000", "1000", "/home/alice"), + }, + expectedAllowed: true, + expectedRequiresSwitch: true, + }, + { + name: "linux_non_root_fallback_to_current_user", + os: "linux", + euid: 1000, // Non-root process + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "bob", + featureSupportsUserSwitch: true, + allowRoot: true, + expectedAllowed: true, // Should fallback to current user (alice) + expectedRequiresSwitch: false, // Fallback means no actual switching + }, + { + name: "windows_admin_can_switch_to_alice", + os: "windows", + euid: 1000, // Irrelevant on Windows + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "alice", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: true, + expectedRequiresSwitch: true, + }, + { + name: "windows_non_admin_no_fallback_hard_failure", + os: "windows", + euid: 1000, // Irrelevant on Windows + currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"), + requestedUsername: "bob", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"), + }, + expectedAllowed: true, // Let OS decide - deferred security check + expectedRequiresSwitch: true, // Different user was requested + }, + // Comprehensive test matrix: non-root linux with different allowRoot settings + { + name: "linux_non_root_request_root_allowRoot_false", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Fallback allows access regardless of root setting + expectedRequiresSwitch: false, // Fallback case, no switching + }, + { + name: "linux_non_root_request_root_allowRoot_true", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + expectedAllowed: true, // Should fallback to alice (non-privileged process) + expectedRequiresSwitch: false, // Fallback means no actual switching + }, + // Windows admin test matrix + { + name: "windows_admin_request_root_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed + expectedRequiresSwitch: true, + }, + { + name: "windows_admin_request_root_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Windows user switching should work like Unix + expectedRequiresSwitch: true, + }, + // Windows non-admin test matrix + { + name: "windows_non_admin_request_root_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence) + expectedRequiresSwitch: true, + }, + { + name: "windows_system_account_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed + expectedRequiresSwitch: true, + }, + { + name: "windows_system_account_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // SYSTEM can switch to root + expectedRequiresSwitch: true, + }, + { + name: "windows_non_admin_request_root_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Let OS decide - deferred security check + expectedRequiresSwitch: true, + }, + + // Feature doesn't support user switching scenarios + { + name: "linux_root_feature_no_user_switching_same_user", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "root", // Same user + featureSupportsUserSwitch: false, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Same user should work regardless of feature support + expectedRequiresSwitch: false, + }, + { + name: "linux_root_feature_no_user_switching_different_user", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "alice", + featureSupportsUserSwitch: false, // Feature doesn't support switching + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "1000", "1000", "/home/alice"), + }, + expectedAllowed: false, // Should deny because feature doesn't support switching + expectedRequiresSwitch: true, + }, + + // Empty username (current user) scenarios + { + name: "linux_non_root_current_user_empty_username", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "", // Empty = current user + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Current user should always work + expectedRequiresSwitch: false, + }, + { + name: "linux_root_current_user_empty_username_root_not_allowed", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "", // Empty = current user (root) + featureSupportsUserSwitch: true, + allowRoot: false, // Root not allowed + expectedAllowed: false, // Should deny root even when it's current user + expectedRequiresSwitch: false, + }, + + // User not found scenarios + { + name: "linux_root_user_not_found", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "nonexistent", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{}, // No users defined = user not found + expectedAllowed: false, // Should fail due to user not found + expectedRequiresSwitch: true, + }, + + // Windows feature doesn't support user switching + { + name: "windows_admin_feature_no_user_switching_different_user", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "alice", + featureSupportsUserSwitch: false, // Feature doesn't support switching + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: false, // Should deny because feature doesn't support switching + expectedRequiresSwitch: true, + }, + + // Windows regular user scenarios (non-admin) + { + name: "windows_regular_user_same_user", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "alice", // Same user + featureSupportsUserSwitch: true, + allowRoot: false, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: true, // Regular user accessing themselves should work + expectedRequiresSwitch: false, // No switching for same user + }, + { + name: "windows_regular_user_empty_username", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "", // Empty = current user + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Current user should always work + expectedRequiresSwitch: false, // No switching for current user + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Inject platform dependencies to test real logic + cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch, + FeatureName: "SSH login", + }) + + assert.Equal(t, tt.expectedAllowed, result.Allowed) + assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching) + }) + } +} + +func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Fallback mechanism is Unix-specific") + } + + // Create test scenario where fallback should occur + server := &Server{allowRootLogin: true} + + // Mock dependencies to simulate non-privileged user + originalGetCurrentUser := getCurrentUser + originalGetIsProcessPrivileged := getIsProcessPrivileged + + defer func() { + getCurrentUser = originalGetCurrentUser + getIsProcessPrivileged = originalGetIsProcessPrivileged + + }() + + // Set up mocks for fallback scenario + getCurrentUser = func() (*user.User, error) { + return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil + } + getIsProcessPrivileged = func() bool { return false } // Non-privileged + + // Request different user - should fallback + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "alice", + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + // Verify fallback occurred + assert.True(t, result.Allowed, "Should allow with fallback") + assert.True(t, result.UsedFallback, "Should indicate fallback was used") + assert.Equal(t, "netbird", result.User.Username, "Should return current user") + assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used") + + // Key assertion: When UsedFallback is true, no privilege dropping should be needed + // because all privilege checks have already been performed and we're using current user + t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed", + result.User.Username) +} + +func TestPrivilegedUsernameDetection(t *testing.T) { + tests := []struct { + name string + username string + platform string + privileged bool + }{ + // Unix/Linux tests + {"unix_root", "root", "linux", true}, + {"unix_regular_user", "alice", "linux", false}, + {"unix_root_capital", "Root", "linux", false}, // Case-sensitive + + // Windows tests + {"windows_administrator", "Administrator", "windows", true}, + {"windows_system", "SYSTEM", "windows", true}, + {"windows_admin", "admin", "windows", true}, + {"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive + {"windows_domain_admin", "DOMAIN\\Administrator", "windows", true}, + {"windows_email_admin", "admin@domain.com", "windows", true}, + {"windows_regular_user", "alice", "windows", false}, + {"windows_domain_user", "DOMAIN\\alice", "windows", false}, + {"windows_localsystem", "localsystem", "windows", true}, + {"windows_networkservice", "networkservice", "windows", true}, + {"windows_localservice", "localservice", "windows", true}, + + // Computer accounts (these depend on current user context in real implementation) + {"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged + {"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account + + // Cross-platform + {"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock the platform for this test + cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil) + defer cleanup() + + result := isPrivilegedUsername(tt.username) + assert.Equal(t, tt.privileged, result) + }) + } +} + +func TestWindowsPrivilegedSIDDetection(t *testing.T) { + tests := []struct { + name string + sid string + privileged bool + description string + }{ + // Well-known system accounts + {"system_account", "S-1-5-18", true, "Local System (SYSTEM)"}, + {"local_service", "S-1-5-19", true, "Local Service"}, + {"network_service", "S-1-5-20", true, "Network Service"}, + {"administrators_group", "S-1-5-32-544", true, "Administrators group"}, + {"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"}, + + // Domain accounts + {"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"}, + {"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"}, + {"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"}, + {"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"}, + + // Regular users + {"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"}, + {"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"}, + {"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"}, + + // Groups that are not privileged + {"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"}, + {"power_users", "S-1-5-32-547", false, "Power Users group"}, + + // Invalid SIDs + {"malformed_sid", "S-1-5-invalid", false, "Malformed SID"}, + {"empty_sid", "", false, "Empty SID"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isWindowsPrivilegedSID(tt.sid) + assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid) + }) + } +} + +func TestIsSameUser(t *testing.T) { + tests := []struct { + name string + user1 string + user2 string + os string + expected bool + }{ + // Basic cases + {"same_username", "alice", "alice", "linux", true}, + {"different_username", "alice", "bob", "linux", false}, + + // Linux (no domain processing) + {"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false}, + {"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false}, + {"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true}, + + // Windows (with domain processing) - Note: parameter order is (requested, current, os, expected) + {"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user + {"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user + {"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users + {"windows_case_insensitive", "Alice", "alice", "windows", true}, + {"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up OS mock + cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil) + defer cleanup() + + result := isSameUser(tt.user1, tt.user2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUsernameValidation_Unix(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Unix-specific username validation tests") + } + + tests := []struct { + name string + username string + wantErr bool + errMsg string + }{ + // Valid usernames (Unix/POSIX) + {"valid_alphanumeric", "user123", false, ""}, + {"valid_with_dots", "user.name", false, ""}, + {"valid_with_hyphens", "user-name", false, ""}, + {"valid_with_underscores", "user_name", false, ""}, + {"valid_uppercase", "UserName", false, ""}, + {"valid_starting_with_digit", "123user", false, ""}, + {"valid_starting_with_dot", ".hidden", false, ""}, + + // Invalid usernames (Unix/POSIX) + {"empty_username", "", true, "username cannot be empty"}, + {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"}, + {"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction + {"username_with_spaces", "user name", true, "invalid characters"}, + {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"}, + {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"}, + {"username_with_pipe", "user|rm", true, "invalid characters"}, + {"username_with_ampersand", "user&rm", true, "invalid characters"}, + {"username_with_quotes", "user\"name", true, "invalid characters"}, + {"username_with_newline", "user\nname", true, "invalid characters"}, + {"reserved_dot", ".", true, "cannot be '.' or '..'"}, + {"reserved_dotdot", "..", true, "cannot be '.' or '..'"}, + {"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames + {"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateUsername(tt.username) + if tt.wantErr { + assert.Error(t, err, "Should reject invalid username") + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text") + } + } else { + assert.NoError(t, err, "Should accept valid username") + } + }) + } +} + +func TestUsernameValidation_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific username validation tests") + } + + tests := []struct { + name string + username string + wantErr bool + errMsg string + }{ + // Valid usernames (Windows) + {"valid_alphanumeric", "user123", false, ""}, + {"valid_with_dots", "user.name", false, ""}, + {"valid_with_hyphens", "user-name", false, ""}, + {"valid_with_underscores", "user_name", false, ""}, + {"valid_uppercase", "UserName", false, ""}, + {"valid_starting_with_digit", "123user", false, ""}, + {"valid_starting_with_dot", ".hidden", false, ""}, + {"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this + {"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format + {"valid_email_username", "user@domain.com", false, ""}, // Windows email format + {"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format + + // Invalid usernames (Windows) + {"empty_username", "", true, "username cannot be empty"}, + {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"}, + {"username_with_spaces", "user name", true, "invalid characters"}, + {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"}, + {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"}, + {"username_with_pipe", "user|rm", true, "invalid characters"}, + {"username_with_ampersand", "user&rm", true, "invalid characters"}, + {"username_with_quotes", "user\"name", true, "invalid characters"}, + {"username_with_newline", "user\nname", true, "invalid characters"}, + {"username_with_brackets", "user[name]", true, "invalid characters"}, + {"username_with_colon", "user:name", true, "invalid characters"}, + {"username_with_semicolon", "user;name", true, "invalid characters"}, + {"username_with_equals", "user=name", true, "invalid characters"}, + {"username_with_comma", "user,name", true, "invalid characters"}, + {"username_with_plus", "user+name", true, "invalid characters"}, + {"username_with_asterisk", "user*name", true, "invalid characters"}, + {"username_with_question", "user?name", true, "invalid characters"}, + {"username_with_angles", "user", true, "invalid characters"}, + {"reserved_dot", ".", true, "cannot be '.' or '..'"}, + {"reserved_dotdot", "..", true, "cannot be '.' or '..'"}, + {"username_ending_with_period", "user.", true, "cannot end with a period"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateUsername(tt.username) + if tt.wantErr { + assert.Error(t, err, "Should reject invalid username") + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text") + } + } else { + assert.NoError(t, err, "Should accept valid username") + } + }) + } +} + +// Test real-world integration scenarios with actual platform capabilities +func TestCheckPrivileges_RealWorldScenarios(t *testing.T) { + tests := []struct { + name string + feature string + featureSupportsUserSwitch bool + requestedUsername string + allowRoot bool + expectedBehaviorPattern string + }{ + {"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"}, + {"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"}, + {"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"}, + {"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"}, + {"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock privileged environment to ensure consistent test behavior across environments + cleanup := setupTestDependencies( + createTestUser("root", "0", "0", "/root"), // Running as root + nil, + runtime.GOOS, + 0, // euid 0 (root) + map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + "differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"), + }, + nil, + ) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch, + FeatureName: tt.feature, + }) + + switch tt.expectedBehaviorPattern { + case "should_allow_current_user": + assert.True(t, result.Allowed, "Should allow current user access") + assert.False(t, result.RequiresUserSwitching, "Current user should not require switching") + case "should_deny_root": + assert.False(t, result.Allowed, "Should deny root when not allowed") + assert.Contains(t, result.Error.Error(), "root", "Should mention root in error") + case "should_deny_switching": + assert.False(t, result.Allowed, "Should deny when feature doesn't support switching") + assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error") + } + }) + } +} + +// Test with actual platform capabilities - no mocking +func TestCheckPrivileges_ActualPlatform(t *testing.T) { + // This test uses the REAL platform capabilities + server := &Server{allowRootLogin: true} + + // Test current user access - should always work + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "", // Current user + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + assert.True(t, result.Allowed, "Current user should always be allowed") + assert.False(t, result.RequiresUserSwitching, "Current user should not require switching") + assert.NotNil(t, result.User, "Should return current user") + + // Test user switching capability based on actual platform + actualIsPrivileged := isCurrentProcessPrivileged() // REAL check + actualOS := runtime.GOOS // REAL check + + t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v", + actualOS, actualIsPrivileged, actualIsPrivileged) + + // Test requesting different user + result = server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "nonexistentuser", + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + switch { + case actualOS == "windows": + // Windows supports user switching but should fail on nonexistent user + assert.False(t, result.Allowed, "Windows should deny nonexistent user") + assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed") + assert.Contains(t, result.Error.Error(), "not found", + "Should indicate user not found") + case !actualIsPrivileged: + // Non-privileged Unix processes should fallback to current user + assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user") + assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens") + assert.True(t, result.UsedFallback, "Should indicate fallback was used") + assert.NotNil(t, result.User, "Should return current user") + default: + // Privileged Unix processes should attempt user lookup + assert.False(t, result.Allowed, "Should fail due to nonexistent user") + assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed") + assert.Contains(t, result.Error.Error(), "nonexistentuser", + "Should indicate user not found") + } +} + +// Test platform detection logic with dependency injection +func TestPlatformLogic_DependencyInjection(t *testing.T) { + tests := []struct { + name string + os string + euid int + currentUser *user.User + expectedIsProcessPrivileged bool + expectedSupportsUserSwitching bool + }{ + { + name: "linux_root_process", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + expectedIsProcessPrivileged: true, + expectedSupportsUserSwitching: true, + }, + { + name: "linux_non_root_process", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + expectedIsProcessPrivileged: false, + expectedSupportsUserSwitching: false, + }, + { + name: "windows_admin_process", + os: "windows", + euid: 1000, // euid ignored on Windows + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + expectedIsProcessPrivileged: true, + expectedSupportsUserSwitching: true, // Windows supports user switching when privileged + }, + { + name: "windows_regular_process", + os: "windows", + euid: 1000, // euid ignored on Windows + currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"), + expectedIsProcessPrivileged: false, + expectedSupportsUserSwitching: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Inject platform dependencies and test REAL logic + cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil) + defer cleanup() + + // Test the actual functions with injected dependencies + actualIsPrivileged := isCurrentProcessPrivileged() + actualSupportsUserSwitching := actualIsPrivileged + + assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged, + "isCurrentProcessPrivileged() result mismatch") + assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching, + "supportsUserSwitching() result mismatch") + + t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username) + t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v", + actualIsPrivileged, actualSupportsUserSwitching) + }) + } +} + +func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) { + // Test Windows elevated user switching scenarios with simplified privilege logic + tests := []struct { + name string + currentUser *user.User + requestedUsername string + allowRoot bool + expectedAllowed bool + expectedErrorContains string + }{ + { + name: "windows_admin_can_switch_to_alice", + currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"), + requestedUsername: "alice", + allowRoot: true, + expectedAllowed: true, + }, + { + name: "windows_non_admin_can_try_switch", + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"), + requestedUsername: "bob", + allowRoot: true, + expectedAllowed: true, // Privilege check allows it, OS will reject during execution + }, + { + name: "windows_system_can_switch_to_alice", + currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"), + requestedUsername: "alice", + allowRoot: true, + expectedAllowed: true, + }, + { + name: "windows_admin_root_not_allowed", + currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"), + requestedUsername: "root", + allowRoot: false, + expectedAllowed: false, + expectedErrorContains: "privileged user login is disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test dependencies with Windows OS and specified privileges + lookupUsers := map[string]*user.User{ + tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername), + } + cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + assert.Equal(t, tt.expectedAllowed, result.Allowed, + "Privilege check result should match expected for %s", tt.name) + + if !tt.expectedAllowed && tt.expectedErrorContains != "" { + assert.NotNil(t, result.Error, "Should have error when not allowed") + assert.Contains(t, result.Error.Error(), tt.expectedErrorContains, + "Error should contain expected message") + } + + if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername { + assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user") + } + }) + } +} diff --git a/client/ssh/server/userswitching_js.go b/client/ssh/server/userswitching_js.go new file mode 100644 index 000000000..333c19259 --- /dev/null +++ b/client/ssh/server/userswitching_js.go @@ -0,0 +1,8 @@ +//go:build js + +package server + +// enableUserSwitching is not supported on JS/WASM +func enableUserSwitching() error { + return errNotSupported +} diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go new file mode 100644 index 000000000..06fefabd7 --- /dev/null +++ b/client/ssh/server/userswitching_unix.go @@ -0,0 +1,233 @@ +//go:build unix + +package server + +import ( + "errors" + "fmt" + "net" + "net/netip" + "os" + "os/exec" + "os/user" + "regexp" + "runtime" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// POSIX portable filename character set regex: [a-zA-Z0-9._-] +// First character cannot be hyphen (POSIX requirement) +var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`) + +// validateUsername validates that a username conforms to POSIX standards with security considerations +func validateUsername(username string) error { + if username == "" { + return errors.New("username cannot be empty") + } + + // POSIX allows up to 256 characters, but practical limit is 32 for compatibility + if len(username) > 32 { + return errors.New("username too long (max 32 characters)") + } + + if !posixUsernameRegex.MatchString(username) { + return errors.New("username contains invalid characters (must match POSIX portable filename character set)") + } + + if username == "." || username == ".." { + return fmt.Errorf("username cannot be '.' or '..'") + } + + // Warn if username is fully numeric (can cause issues with UID/username ambiguity) + if isFullyNumeric(username) { + log.Warnf("fully numeric username '%s' may cause issues with some commands", username) + } + + return nil +} + +// isFullyNumeric checks if username contains only digits +func isFullyNumeric(username string) bool { + for _, char := range username { + if char < '0' || char > '9' { + return false + } + } + return true +} + +// createPtyLoginCommand creates a Pty command using login for privileged processes +func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { + loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("get login command: %w", err) + } + + execCmd := exec.CommandContext(session.Context(), loginPath, args...) + execCmd.Dir = localUser.HomeDir + execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session) + + return execCmd, nil +} + +// getLoginCmd returns the login command and args for privileged Pty user switching +func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) { + loginPath, err := exec.LookPath("login") + if err != nil { + return "", nil, fmt.Errorf("login command not available: %w", err) + } + + addrPort, err := netip.ParseAddrPort(remoteAddr.String()) + if err != nil { + return "", nil, fmt.Errorf("parse remote address: %w", err) + } + + switch runtime.GOOS { + case "linux": + // Special handling for Arch Linux without /etc/pam.d/remote + if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { + return loginPath, []string{"-f", username, "-p"}, nil + } + return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil + case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": + return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil + default: + return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS) + } +} + +// fileExists checks if a file exists (helper for login command logic) +func (s *Server) fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +// parseUserCredentials extracts numeric UID, GID, and supplementary groups +func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) { + uid64, err := strconv.ParseUint(localUser.Uid, 10, 32) + if err != nil { + return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err) + } + uid := uint32(uid64) + + gid64, err := strconv.ParseUint(localUser.Gid, 10, 32) + if err != nil { + return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err) + } + gid := uint32(gid64) + + groups, err := s.getSupplementaryGroups(localUser.Username) + if err != nil { + log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + groups = []uint32{gid} + } + + return uid, gid, groups, nil +} + +// getSupplementaryGroups retrieves supplementary group IDs for a user +func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { + u, err := user.Lookup(username) + if err != nil { + return nil, fmt.Errorf("lookup user %s: %w", username, err) + } + + groupIDStrings, err := u.GroupIds() + if err != nil { + return nil, fmt.Errorf("get group IDs for user %s: %w", username, err) + } + + groups := make([]uint32, len(groupIDStrings)) + for i, gidStr := range groupIDStrings { + gid64, err := strconv.ParseUint(gidStr, 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err) + } + groups[i] = uint32(gid64) + } + + return groups, nil +} + +// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. +// Returns the command and a cleanup function (no-op on Unix). +func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) + + if err := validateUsername(localUser.Username); err != nil { + return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) + } + + uid, gid, groups, err := s.parseUserCredentials(localUser) + if err != nil { + return nil, nil, fmt.Errorf("parse user credentials: %w", err) + } + privilegeDropper := NewPrivilegeDropper() + config := ExecutorConfig{ + UID: uid, + GID: gid, + Groups: groups, + WorkingDir: localUser.HomeDir, + Shell: getUserShell(localUser.Uid), + Command: session.RawCommand(), + PTY: hasPty, + } + + cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config) + return cmd, func() {}, err +} + +// enableUserSwitching is a no-op on Unix systems +func enableUserSwitching() error { + return nil +} + +// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results +func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { + localUser := privilegeResult.User + if localUser == nil { + return nil, errors.New("no user in privilege result") + } + + if privilegeResult.UsedFallback { + return s.createDirectPtyCommand(session, localUser, ptyReq), nil + } + + return s.createPtyLoginCommand(localUser, ptyReq, session) +} + +// createDirectPtyCommand creates a direct Pty command without privilege dropping +func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd { + log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username) + + shell := getUserShell(localUser.Uid) + args := s.getShellCommandArgs(shell, session.RawCommand()) + + cmd := exec.CommandContext(session.Context(), args[0], args[1:]...) + cmd.Dir = localUser.HomeDir + cmd.Env = s.preparePtyEnv(localUser, ptyReq, session) + + return cmd +} + +// preparePtyEnv prepares environment variables for Pty execution +func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string { + termType := ptyReq.Term + if termType == "" { + termType = "xterm-256color" + } + + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + env = append(env, fmt.Sprintf("TERM=%s", termType)) + + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go new file mode 100644 index 000000000..5a5f75fa4 --- /dev/null +++ b/client/ssh/server/userswitching_windows.go @@ -0,0 +1,274 @@ +//go:build windows + +package server + +import ( + "errors" + "fmt" + "os/exec" + "os/user" + "strings" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +// validateUsername validates Windows usernames according to SAM Account Name rules +func validateUsername(username string) error { + if username == "" { + return fmt.Errorf("username cannot be empty") + } + + usernameToValidate := extractUsernameFromDomain(username) + + if err := validateUsernameLength(usernameToValidate); err != nil { + return err + } + + if err := validateUsernameCharacters(usernameToValidate); err != nil { + return err + } + + if err := validateUsernameFormat(usernameToValidate); err != nil { + return err + } + + return nil +} + +// extractUsernameFromDomain extracts the username part from domain\username or username@domain format +func extractUsernameFromDomain(username string) string { + if idx := strings.LastIndex(username, `\`); idx != -1 { + return username[idx+1:] + } + if idx := strings.Index(username, "@"); idx != -1 { + return username[:idx] + } + return username +} + +// validateUsernameLength checks if username length is within Windows limits +func validateUsernameLength(username string) error { + if len(username) > 20 { + return fmt.Errorf("username too long (max 20 characters for Windows)") + } + return nil +} + +// validateUsernameCharacters checks for invalid characters in Windows usernames +func validateUsernameCharacters(username string) error { + invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'} + for _, char := range username { + for _, invalid := range invalidChars { + if char == invalid { + return fmt.Errorf("username contains invalid characters") + } + } + if char < 32 || char == 127 { + return fmt.Errorf("username contains control characters") + } + } + return nil +} + +// validateUsernameFormat checks for invalid username formats and patterns +func validateUsernameFormat(username string) error { + if username == "." || username == ".." { + return fmt.Errorf("username cannot be '.' or '..'") + } + + if strings.HasSuffix(username, ".") { + return fmt.Errorf("username cannot end with a period") + } + + return nil +} + +// createExecutorCommand creates a command using Windows executor for privilege dropping. +// Returns the command and a cleanup function that must be called after starting the process. +func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) + + username, _ := s.parseUsername(localUser.Username) + if err := validateUsername(username); err != nil { + return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) + } + + return s.createUserSwitchCommand(localUser, session, hasPty) +} + +// createUserSwitchCommand creates a command with Windows user switching. +// Returns the command and a cleanup function that must be called after starting the process. +func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) { + username, domain := s.parseUsername(localUser.Username) + + shell := getUserShell(localUser.Uid) + + rawCmd := session.RawCommand() + var command string + if rawCmd != "" { + command = rawCmd + } + + config := WindowsExecutorConfig{ + Username: username, + Domain: domain, + WorkingDir: localUser.HomeDir, + Shell: shell, + Command: command, + Interactive: interactive || (rawCmd == ""), + } + + dropper := NewPrivilegeDropper() + cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) + if err != nil { + return nil, nil, err + } + + cleanup := func() { + if token != 0 { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("close primary token: %v", err) + } + } + } + + return cmd, cleanup, nil +} + +// parseUsername extracts username and domain from a Windows username +func (s *Server) parseUsername(fullUsername string) (username, domain string) { + // Handle DOMAIN\username format + if idx := strings.LastIndex(fullUsername, `\`); idx != -1 { + domain = fullUsername[:idx] + username = fullUsername[idx+1:] + return username, domain + } + + // Handle username@domain format + if username, domain, ok := strings.Cut(fullUsername, "@"); ok { + return username, domain + } + + // Local user (no domain) + return fullUsername, "." +} + +// hasPrivilege checks if the current process has a specific privilege +func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) { + var luid windows.LUID + if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil { + return false, fmt.Errorf("lookup privilege value: %w", err) + } + + var returnLength uint32 + err := windows.GetTokenInformation( + windows.Token(token), + windows.TokenPrivileges, + nil, // null buffer to get size + 0, + &returnLength, + ) + + if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return false, fmt.Errorf("get token information size: %w", err) + } + + buffer := make([]byte, returnLength) + err = windows.GetTokenInformation( + windows.Token(token), + windows.TokenPrivileges, + &buffer[0], + returnLength, + &returnLength, + ) + if err != nil { + return false, fmt.Errorf("get token information: %w", err) + } + + privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0])) + + // Check if the privilege is present and enabled + for i := uint32(0); i < privileges.PrivilegeCount; i++ { + privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer( + uintptr(unsafe.Pointer(&privileges.Privileges[0])) + + uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}), + )) + if privilege.Luid == luid { + return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil + } + } + + return false, nil +} + +// enablePrivilege enables a specific privilege for the current process token +// This is required because privileges like SeAssignPrimaryTokenPrivilege are present +// but disabled by default, even for the SYSTEM account +func enablePrivilege(token windows.Handle, privilegeName string) error { + var luid windows.LUID + if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil { + return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err) + } + + privileges := windows.Tokenprivileges{ + PrivilegeCount: 1, + Privileges: [1]windows.LUIDAndAttributes{ + { + Luid: luid, + Attributes: windows.SE_PRIVILEGE_ENABLED, + }, + }, + } + + err := windows.AdjustTokenPrivileges( + windows.Token(token), + false, + &privileges, + 0, + nil, + nil, + ) + if err != nil { + return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err) + } + + hasPriv, err := hasPrivilege(token, privilegeName) + if err != nil { + return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err) + } + if !hasPriv { + return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName) + } + + log.Debugf("Successfully enabled privilege %s for current process", privilegeName) + return nil +} + +// enableUserSwitching enables required privileges for Windows user switching +func enableUserSwitching() error { + process := windows.CurrentProcess() + + var token windows.Token + err := windows.OpenProcessToken( + process, + windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, + &token, + ) + if err != nil { + return fmt.Errorf("open process token: %w", err) + } + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("Failed to close process token: %v", err) + } + }() + + if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil { + return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err) + } + log.Infof("Windows user switching privileges enabled successfully") + return nil +} diff --git a/client/ssh/server/winpty/conpty.go b/client/ssh/server/winpty/conpty.go new file mode 100644 index 000000000..0f3659ffe --- /dev/null +++ b/client/ssh/server/winpty/conpty.go @@ -0,0 +1,487 @@ +//go:build windows + +package winpty + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "syscall" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +var ( + ErrEmptyEnvironment = errors.New("empty environment") +) + +const ( + extendedStartupInfoPresent = 0x00080000 + createUnicodeEnvironment = 0x00000400 + procThreadAttributePseudoConsole = 0x00020016 + + PowerShellCommandFlag = "-Command" + + errCloseInputRead = "close input read handle: %v" + errCloseConPtyCleanup = "close ConPty handle during cleanup" +) + +// PtyConfig holds configuration for Pty execution. +type PtyConfig struct { + Shell string + Command string + Width int + Height int + WorkingDir string +} + +// UserConfig holds user execution configuration. +type UserConfig struct { + Token windows.Handle + Environment []string +} + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") + procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList") + procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute") + procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList") +) + +// ExecutePtyWithUserToken executes a command with ConPty using user token. +func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { + args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command) + commandLine := buildCommandLine(args) + + config := ExecutionConfig{ + Pty: ptyConfig, + User: userConfig, + Session: session, + Context: ctx, + } + + return executeConPtyWithConfig(commandLine, config) +} + +// ExecutionConfig holds all execution configuration. +type ExecutionConfig struct { + Pty PtyConfig + User UserConfig + Session ssh.Session + Context context.Context +} + +// executeConPtyWithConfig creates ConPty and executes process with configuration. +func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error { + ctx := config.Context + session := config.Session + width := config.Pty.Width + height := config.Pty.Height + userToken := config.User.Token + userEnv := config.User.Environment + workingDir := config.Pty.WorkingDir + + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + if err != nil { + return fmt.Errorf("create ConPty pipes: %w", err) + } + + hPty, err := createConPty(width, height, inputRead, outputWrite) + if err != nil { + return fmt.Errorf("create ConPty: %w", err) + } + + primaryToken, err := duplicateToPrimaryToken(userToken) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("duplicate to primary token: %w", err) + } + defer func() { + if err := windows.CloseHandle(primaryToken); err != nil { + log.Debugf("close primary token: %v", err) + } + }() + + siEx, err := setupConPtyStartupInfo(hPty) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("setup startup info: %w", err) + } + defer func() { + _, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList))) + }() + + pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("create process as user with ConPty: %w", err) + } + defer closeProcessInfo(pi) + + if err := windows.CloseHandle(inputRead); err != nil { + log.Debugf(errCloseInputRead, err) + } + if err := windows.CloseHandle(outputWrite); err != nil { + log.Debugf("close output write handle: %v", err) + } + + return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process) +} + +// createConPtyPipes creates input/output pipes for ConPty. +func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) { + if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil { + return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err) + } + + if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil { + if closeErr := windows.CloseHandle(inputRead); closeErr != nil { + log.Debugf(errCloseInputRead, closeErr) + } + if closeErr := windows.CloseHandle(inputWrite); closeErr != nil { + log.Debugf("close input write handle: %v", closeErr) + } + return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err) + } + + return inputRead, inputWrite, outputRead, outputWrite, nil +} + +// createConPty creates a Windows ConPty with the specified size and pipe handles. +func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) { + size := windows.Coord{X: int16(width), Y: int16(height)} + + var hPty windows.Handle + if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil { + return 0, fmt.Errorf("CreatePseudoConsole: %w", err) + } + + return hPty, nil +} + +// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes. +func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) { + var siEx windows.StartupInfoEx + siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx)) + + var attrListSize uintptr + ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize))) + if ret == 0 && attrListSize == 0 { + return nil, fmt.Errorf("get attribute list size") + } + + attrListBytes := make([]byte, attrListSize) + siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0])) + + ret, _, err := procInitializeProcThreadAttributeList.Call( + uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)), + 1, + 0, + uintptr(unsafe.Pointer(&attrListSize)), + ) + if ret == 0 { + return nil, fmt.Errorf("initialize attribute list: %w", err) + } + + ret, _, err = procUpdateProcThreadAttribute.Call( + uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)), + 0, + procThreadAttributePseudoConsole, + uintptr(hPty), + unsafe.Sizeof(hPty), + 0, + 0, + ) + if ret == 0 { + return nil, fmt.Errorf("update thread attribute: %w", err) + } + + return &siEx, nil +} + +// createConPtyProcess creates the actual process with ConPty. +func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) { + var pi windows.ProcessInformation + creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment) + + commandLinePtr, err := windows.UTF16PtrFromString(commandLine) + if err != nil { + return nil, fmt.Errorf("convert command line to UTF16: %w", err) + } + + envPtr, err := convertEnvironmentToUTF16(userEnv) + if err != nil { + return nil, err + } + + var workingDirPtr *uint16 + if workingDir != "" { + workingDirPtr, err = windows.UTF16PtrFromString(workingDir) + if err != nil { + return nil, fmt.Errorf("convert working directory to UTF16: %w", err) + } + } + + siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES + siEx.StartupInfo.StdInput = windows.Handle(0) + siEx.StartupInfo.StdOutput = windows.Handle(0) + siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput + + if userToken != windows.InvalidHandle { + err = windows.CreateProcessAsUser( + windows.Token(userToken), + nil, + commandLinePtr, + nil, + nil, + true, + creationFlags, + envPtr, + workingDirPtr, + &siEx.StartupInfo, + &pi, + ) + } else { + err = windows.CreateProcess( + nil, + commandLinePtr, + nil, + nil, + true, + creationFlags, + envPtr, + workingDirPtr, + &siEx.StartupInfo, + &pi, + ) + } + + if err != nil { + return nil, fmt.Errorf("create process: %w", err) + } + + return &pi, nil +} + +// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format. +func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) { + if len(userEnv) == 0 { + // Return nil pointer for empty environment - Windows API will inherit parent environment + return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment + } + + var envUTF16 []uint16 + for _, envVar := range userEnv { + if envVar != "" { + utf16Str, err := windows.UTF16FromString(envVar) + if err != nil { + log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err) + continue + } + envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...) + envUTF16 = append(envUTF16, 0) + } + } + envUTF16 = append(envUTF16, 0) + + if len(envUTF16) > 0 { + return &envUTF16[0], nil + } + // Return nil pointer when no valid environment variables found + return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment +} + +// duplicateToPrimaryToken converts an impersonation token to a primary token. +func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) { + var primaryToken windows.Handle + if err := windows.DuplicateTokenEx( + windows.Token(token), + windows.TOKEN_ALL_ACCESS, + nil, + windows.SecurityImpersonation, + windows.TokenPrimary, + (*windows.Token)(&primaryToken), + ); err != nil { + return 0, fmt.Errorf("duplicate token: %w", err) + } + return primaryToken, nil +} + +// SessionExiter provides the Exit method for reporting process exit status. +type SessionExiter interface { + Exit(code int) error +} + +// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers. +func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error { + if err := ctx.Err(); err != nil { + return err + } + + var wg sync.WaitGroup + startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer) + + processErr := waitForProcess(ctx, process) + if processErr != nil { + return processErr + } + + var exitCode uint32 + if err := windows.GetExitCodeProcess(process, &exitCode); err != nil { + log.Debugf("get exit code: %v", err) + } else { + if err := session.Exit(int(exitCode)); err != nil { + log.Debugf("report exit code: %v", err) + } + } + + // Clean up in the original order after process completes + if err := reader.Close(); err != nil { + log.Debugf("close reader: %v", err) + } + + ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)) + if ret == 0 { + log.Debugf("close ConPty handle: %v", err) + } + + wg.Wait() + + if err := windows.CloseHandle(outputRead); err != nil { + log.Debugf("close output read handle: %v", err) + } + + return nil +} + +// startIOBridging starts the I/O bridging goroutines. +func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) { + wg.Add(2) + + // Input: reader (SSH session) -> inputWrite (ConPty) + go func() { + defer wg.Done() + defer func() { + if err := windows.CloseHandle(inputWrite); err != nil { + log.Debugf("close input write handle in goroutine: %v", err) + } + }() + + if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil { + log.Debugf("input copy ended with error: %v", err) + } + }() + + // Output: outputRead (ConPty) -> writer (SSH session) + go func() { + defer wg.Done() + if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil { + log.Debugf("output copy ended with error: %v", err) + } + }() +} + +// waitForProcess waits for process completion with context cancellation. +func waitForProcess(ctx context.Context, process windows.Handle) error { + if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil { + return fmt.Errorf("wait for process %d: %w", process, err) + } + return nil +} + +// buildShellArgs builds shell arguments for ConPty execution. +func buildShellArgs(shell, command string) []string { + if command != "" { + return []string{shell, PowerShellCommandFlag, command} + } + return []string{shell} +} + +// buildCommandLine builds a Windows command line from arguments using proper escaping. +func buildCommandLine(args []string) string { + if len(args) == 0 { + return "" + } + + var result strings.Builder + for i, arg := range args { + if i > 0 { + result.WriteString(" ") + } + result.WriteString(syscall.EscapeArg(arg)) + } + return result.String() +} + +// closeHandles closes multiple Windows handles. +func closeHandles(handles ...windows.Handle) { + for _, handle := range handles { + if handle != windows.InvalidHandle { + if err := windows.CloseHandle(handle); err != nil { + log.Debugf("close handle: %v", err) + } + } + } +} + +// closeProcessInfo closes process and thread handles. +func closeProcessInfo(pi *windows.ProcessInformation) { + if pi != nil { + if err := windows.CloseHandle(pi.Process); err != nil { + log.Debugf("close process handle: %v", err) + } + if err := windows.CloseHandle(pi.Thread); err != nil { + log.Debugf("close thread handle: %v", err) + } + } +} + +// windowsHandleReader wraps a Windows handle for reading. +type windowsHandleReader struct { + handle windows.Handle +} + +func (r *windowsHandleReader) Read(p []byte) (n int, err error) { + var bytesRead uint32 + if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil { + return 0, err + } + return int(bytesRead), nil +} + +func (r *windowsHandleReader) Close() error { + return windows.CloseHandle(r.handle) +} + +// windowsHandleWriter wraps a Windows handle for writing. +type windowsHandleWriter struct { + handle windows.Handle +} + +func (w *windowsHandleWriter) Write(p []byte) (n int, err error) { + var bytesWritten uint32 + if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil { + return 0, err + } + return int(bytesWritten), nil +} + +func (w *windowsHandleWriter) Close() error { + return windows.CloseHandle(w.handle) +} diff --git a/client/ssh/server/winpty/conpty_test.go b/client/ssh/server/winpty/conpty_test.go new file mode 100644 index 000000000..4f04e1fad --- /dev/null +++ b/client/ssh/server/winpty/conpty_test.go @@ -0,0 +1,290 @@ +//go:build windows + +package winpty + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" +) + +func TestBuildShellArgs(t *testing.T) { + tests := []struct { + name string + shell string + command string + expected []string + }{ + { + name: "Shell with command", + shell: "powershell.exe", + command: "Get-Process", + expected: []string{"powershell.exe", "-Command", "Get-Process"}, + }, + { + name: "CMD with command", + shell: "cmd.exe", + command: "dir", + expected: []string{"cmd.exe", "-Command", "dir"}, + }, + { + name: "Shell interactive", + shell: "powershell.exe", + command: "", + expected: []string{"powershell.exe"}, + }, + { + name: "CMD interactive", + shell: "cmd.exe", + command: "", + expected: []string{"cmd.exe"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildShellArgs(tt.shell, tt.command) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBuildCommandLine(t *testing.T) { + tests := []struct { + name string + args []string + expected string + }{ + { + name: "Simple args", + args: []string{"cmd.exe", "/c", "echo"}, + expected: "cmd.exe /c echo", + }, + { + name: "Args with spaces", + args: []string{"Program Files\\app.exe", "arg with spaces"}, + expected: `"Program Files\app.exe" "arg with spaces"`, + }, + { + name: "Args with quotes", + args: []string{"cmd.exe", "/c", `echo "hello world"`}, + expected: `cmd.exe /c "echo \"hello world\""`, + }, + { + name: "PowerShell calling PowerShell", + args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`}, + expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`, + }, + { + name: "Complex nested quotes", + args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`}, + expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`, + }, + { + name: "Path with spaces and args", + args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`}, + expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`, + }, + { + name: "Empty argument", + args: []string{"cmd.exe", "/c", "echo", ""}, + expected: `cmd.exe /c echo ""`, + }, + { + name: "Argument with backslashes", + args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"}, + expected: `robocopy C:\Source\ C:\Dest\ /E`, + }, + { + name: "Empty args", + args: []string{}, + expected: "", + }, + { + name: "Single arg with space", + args: []string{"path with spaces"}, + expected: `"path with spaces"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildCommandLine(tt.args) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCreateConPtyPipes(t *testing.T) { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + require.NoError(t, err, "Should create ConPty pipes successfully") + + // Verify all handles are valid + assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid") + + // Clean up handles + closeHandles(inputRead, inputWrite, outputRead, outputWrite) +} + +func TestCreateConPty(t *testing.T) { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + require.NoError(t, err, "Should create ConPty pipes successfully") + defer closeHandles(inputRead, inputWrite, outputRead, outputWrite) + + hPty, err := createConPty(80, 24, inputRead, outputWrite) + require.NoError(t, err, "Should create ConPty successfully") + assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid") + + // Clean up ConPty + ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty)) + assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully") +} + +func TestConvertEnvironmentToUTF16(t *testing.T) { + tests := []struct { + name string + userEnv []string + hasError bool + }{ + { + name: "Valid environment variables", + userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"}, + hasError: false, + }, + { + name: "Empty environment", + userEnv: []string{}, + hasError: false, + }, + { + name: "Environment with empty strings", + userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"}, + hasError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := convertEnvironmentToUTF16(tt.userEnv) + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if len(tt.userEnv) == 0 { + assert.Nil(t, result, "Empty environment should return nil") + } else { + assert.NotNil(t, result, "Non-empty environment should return valid pointer") + } + } + }) + } +} + +func TestDuplicateToPrimaryToken(t *testing.T) { + if testing.Short() { + t.Skip("Skipping token tests in short mode") + } + + // Get current process token for testing + var token windows.Token + err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token) + require.NoError(t, err, "Should open current process token") + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + t.Logf("Failed to close token: %v", err) + } + }() + + primaryToken, err := duplicateToPrimaryToken(windows.Handle(token)) + require.NoError(t, err, "Should duplicate token to primary") + assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid") + + // Clean up + err = windows.CloseHandle(primaryToken) + assert.NoError(t, err, "Should close primary token") +} + +func TestWindowsHandleReader(t *testing.T) { + // Create a pipe for testing + var readHandle, writeHandle windows.Handle + err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0) + require.NoError(t, err, "Should create pipe for testing") + defer closeHandles(readHandle, writeHandle) + + // Write test data + testData := []byte("Hello, Windows Handle Reader!") + var bytesWritten uint32 + err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil) + require.NoError(t, err, "Should write test data") + require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data") + + // Close write handle to signal EOF + if err := windows.CloseHandle(writeHandle); err != nil { + t.Fatalf("Should close write handle: %v", err) + } + writeHandle = windows.InvalidHandle + + // Test reading + reader := &windowsHandleReader{handle: readHandle} + buffer := make([]byte, len(testData)) + n, err := reader.Read(buffer) + require.NoError(t, err, "Should read from handle") + assert.Equal(t, len(testData), n, "Should read expected number of bytes") + assert.Equal(t, testData, buffer, "Should read expected data") +} + +func TestWindowsHandleWriter(t *testing.T) { + // Create a pipe for testing + var readHandle, writeHandle windows.Handle + err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0) + require.NoError(t, err, "Should create pipe for testing") + defer closeHandles(readHandle, writeHandle) + + // Test writing + testData := []byte("Hello, Windows Handle Writer!") + writer := &windowsHandleWriter{handle: writeHandle} + n, err := writer.Write(testData) + require.NoError(t, err, "Should write to handle") + assert.Equal(t, len(testData), n, "Should write expected number of bytes") + + // Close write handle + if err := windows.CloseHandle(writeHandle); err != nil { + t.Fatalf("Should close write handle: %v", err) + } + + // Verify data was written by reading it back + buffer := make([]byte, len(testData)) + var bytesRead uint32 + err = windows.ReadFile(readHandle, buffer, &bytesRead, nil) + require.NoError(t, err, "Should read back written data") + assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes") + assert.Equal(t, testData, buffer, "Should read back expected data") +} + +// BenchmarkConPtyCreation benchmarks ConPty creation performance +func BenchmarkConPtyCreation(b *testing.B) { + for i := 0; i < b.N; i++ { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + if err != nil { + b.Fatal(err) + } + + hPty, err := createConPty(80, 24, inputRead, outputWrite) + if err != nil { + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + b.Fatal(err) + } + + // Clean up + if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 { + log.Debugf("ClosePseudoConsole failed: %v", err) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + } +} diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go deleted file mode 100644 index 76f43fd4e..000000000 --- a/client/ssh/server_mock.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build !js - -package ssh - -import "context" - -// MockServer mocks ssh.Server -type MockServer struct { - Ctx context.Context - StopFunc func() error - StartFunc func() error - AddAuthorizedKeyFunc func(peer, newKey string) error - RemoveAuthorizedKeyFunc func(peer string) -} - -// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys -func (srv *MockServer) RemoveAuthorizedKey(peer string) { - if srv.RemoveAuthorizedKeyFunc == nil { - return - } - srv.RemoveAuthorizedKeyFunc(peer) -} - -// AddAuthorizedKey add a given peer key to server authorized keys -func (srv *MockServer) AddAuthorizedKey(peer, newKey string) error { - if srv.AddAuthorizedKeyFunc == nil { - return nil - } - return srv.AddAuthorizedKeyFunc(peer, newKey) -} - -// Stop stops SSH server. -func (srv *MockServer) Stop() error { - if srv.StopFunc == nil { - return nil - } - return srv.StopFunc() -} - -// Start starts SSH server. Blocking -func (srv *MockServer) Start() error { - if srv.StartFunc == nil { - return nil - } - return srv.StartFunc() -} diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go deleted file mode 100644 index 1f310c2bb..000000000 --- a/client/ssh/server_test.go +++ /dev/null @@ -1,123 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/ssh" - "strings" - "testing" -) - -func TestServer_AddAuthorizedKey(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - // add multiple keys - keys := map[string][]byte{} - for i := 0; i < 10; i++ { - peer := fmt.Sprintf("%s-%d", "remotePeer", i) - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey(peer, string(remotePubKey)) - if err != nil { - t.Error(err) - } - keys[peer] = remotePubKey - } - - // make sure that all keys have been added - for peer, remotePubKey := range keys { - k, ok := server.authorizedKeys[peer] - assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys") - - assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(ssh.MarshalAuthorizedKey(k)))) - } - -} - -func TestServer_RemoveAuthorizedKey(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey("remotePeer", string(remotePubKey)) - if err != nil { - t.Error(err) - } - - server.RemoveAuthorizedKey("remotePeer") - - _, ok := server.authorizedKeys["remotePeer"] - assert.False(t, ok, "expecting remotePeer's SSH key to be removed") -} - -func TestServer_PubKeyHandler(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - var keys []ssh.PublicKey - for i := 0; i < 10; i++ { - peer := fmt.Sprintf("%s-%d", "remotePeer", i) - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey(peer, string(remotePubKey)) - if err != nil { - t.Error(err) - } - keys = append(keys, remoteParsedPubKey) - } - - for _, key := range keys { - accepted := server.publicKeyHandler(nil, key) - - assert.Truef(t, accepted, "expecting SSH connection to be accepted for a given SSH key %s", string(ssh.MarshalAuthorizedKey(key))) - } - -} diff --git a/client/ssh/util.go b/client/ssh/ssh.go similarity index 86% rename from client/ssh/util.go rename to client/ssh/ssh.go index a54a609bc..c0024c599 100644 --- a/client/ssh/util.go +++ b/client/ssh/ssh.go @@ -32,9 +32,8 @@ const RSA KeyType = "rsa" // RSAKeySize is a size of newly generated RSA key const RSAKeySize = 2048 -// GeneratePrivateKey creates RSA Private Key of specified byte size +// GeneratePrivateKey creates a private key of the specified type. func GeneratePrivateKey(keyType KeyType) ([]byte, error) { - var key crypto.Signer var err error switch keyType { @@ -59,7 +58,7 @@ func GeneratePrivateKey(keyType KeyType) ([]byte, error) { return pemBytes, nil } -// GeneratePublicKey returns the public part of the private key +// GeneratePublicKey returns the public part of the private key. func GeneratePublicKey(key []byte) ([]byte, error) { signer, err := gossh.ParsePrivateKey(key) if err != nil { @@ -70,20 +69,17 @@ func GeneratePublicKey(key []byte) ([]byte, error) { return []byte(strKey), nil } -// EncodePrivateKeyToPEM encodes Private Key from RSA to PEM format +// EncodePrivateKeyToPEM encodes a private key to PEM format. func EncodePrivateKeyToPEM(privateKey crypto.Signer) ([]byte, error) { mk, err := x509.MarshalPKCS8PrivateKey(privateKey) if err != nil { return nil, err } - // pem.Block privBlock := pem.Block{ Type: "PRIVATE KEY", Bytes: mk, } - - // Private key in PEM format privatePEM := pem.EncodeToMemory(&privBlock) return privatePEM, nil } diff --git a/client/ssh/testutil/user_helpers.go b/client/ssh/testutil/user_helpers.go new file mode 100644 index 000000000..0c1222078 --- /dev/null +++ b/client/ssh/testutil/user_helpers.go @@ -0,0 +1,172 @@ +package testutil + +import ( + "fmt" + "log" + "os" + "os/exec" + "os/user" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +var testCreatedUsers = make(map[string]bool) +var testUsersToCleanup []string + +// GetTestUsername returns an appropriate username for testing +func GetTestUsername(t *testing.T) string { + if runtime.GOOS == "windows" { + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + if IsSystemAccount(currentUser.Username) { + if IsCI() { + if testUser := GetOrCreateTestUser(t); testUser != "" { + return testUser + } + } else { + if _, err := user.Lookup("Administrator"); err == nil { + return "Administrator" + } + if testUser := GetOrCreateTestUser(t); testUser != "" { + return testUser + } + } + } + return currentUser.Username + } + + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + return currentUser.Username +} + +// IsCI checks if we're running in a CI environment +func IsCI() bool { + if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" { + return true + } + + hostname, err := os.Hostname() + if err == nil && strings.HasPrefix(hostname, "runner") { + return true + } + + return false +} + +// IsSystemAccount checks if the user is a system account that can't authenticate +func IsSystemAccount(username string) bool { + systemAccounts := []string{ + "system", + "NT AUTHORITY\\SYSTEM", + "NT AUTHORITY\\LOCAL SERVICE", + "NT AUTHORITY\\NETWORK SERVICE", + } + + for _, sysAccount := range systemAccounts { + if strings.EqualFold(username, sysAccount) { + return true + } + } + return false +} + +// RegisterTestUserCleanup registers a test user for cleanup +func RegisterTestUserCleanup(username string) { + if !testCreatedUsers[username] { + testCreatedUsers[username] = true + testUsersToCleanup = append(testUsersToCleanup, username) + } +} + +// CleanupTestUsers removes all created test users +func CleanupTestUsers() { + for _, username := range testUsersToCleanup { + RemoveWindowsTestUser(username) + } + testUsersToCleanup = nil + testCreatedUsers = make(map[string]bool) +} + +// GetOrCreateTestUser creates a test user on Windows if needed +func GetOrCreateTestUser(t *testing.T) string { + testUsername := "netbird-test-user" + + if _, err := user.Lookup(testUsername); err == nil { + return testUsername + } + + if CreateWindowsTestUser(t, testUsername) { + RegisterTestUserCleanup(testUsername) + return testUsername + } + + return "" +} + +// RemoveWindowsTestUser removes a local user on Windows using PowerShell +func RemoveWindowsTestUser(username string) { + if runtime.GOOS != "windows" { + return + } + + psCmd := fmt.Sprintf(` + try { + Remove-LocalUser -Name "%s" -ErrorAction Stop + Write-Output "User removed successfully" + } catch { + if ($_.Exception.Message -like "*cannot be found*") { + Write-Output "User not found (already removed)" + } else { + Write-Error $_.Exception.Message + } + } + `, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output)) + } else { + log.Printf("Test user %s cleanup result: %s", username, string(output)) + } +} + +// CreateWindowsTestUser creates a local user on Windows using PowerShell +func CreateWindowsTestUser(t *testing.T, username string) bool { + if runtime.GOOS != "windows" { + return false + } + + psCmd := fmt.Sprintf(` + try { + $password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force + New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires + Add-LocalGroupMember -Group "Users" -Member "%s" + Write-Output "User created successfully" + } catch { + if ($_.Exception.Message -like "*already exists*") { + Write-Output "User already exists" + } else { + Write-Error $_.Exception.Message + exit 1 + } + } + `, username, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + t.Logf("Failed to create test user: %v, output: %s", err, string(output)) + return false + } + + t.Logf("Test user creation result: %s", string(output)) + return true +} diff --git a/client/ssh/window_freebsd.go b/client/ssh/window_freebsd.go deleted file mode 100644 index ef4848341..000000000 --- a/client/ssh/window_freebsd.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build freebsd - -package ssh - -import ( - "os" -) - -func setWinSize(file *os.File, width, height int) { -} diff --git a/client/ssh/window_unix.go b/client/ssh/window_unix.go deleted file mode 100644 index 2891eb70e..000000000 --- a/client/ssh/window_unix.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build linux || darwin - -package ssh - -import ( - "os" - "syscall" - "unsafe" -) - -func setWinSize(file *os.File, width, height int) { - syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), //nolint - uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(height), uint16(width), 0, 0}))) -} diff --git a/client/ssh/window_windows.go b/client/ssh/window_windows.go deleted file mode 100644 index 5abd41f27..000000000 --- a/client/ssh/window_windows.go +++ /dev/null @@ -1,9 +0,0 @@ -package ssh - -import ( - "os" -) - -func setWinSize(file *os.File, width, height int) { - -} diff --git a/client/status/status.go b/client/status/status.go index 8a0b7bae0..d975f0e29 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -81,6 +81,18 @@ type NsServerGroupStateOutput struct { Error string `json:"error" yaml:"error"` } +type SSHSessionOutput struct { + Username string `json:"username" yaml:"username"` + RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` + Command string `json:"command" yaml:"command"` + JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` +} + +type SSHServerStateOutput struct { + Enabled bool `json:"enabled" yaml:"enabled"` + Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"` +} + type OutputOverview struct { Peers PeersStateOutput `json:"peers" yaml:"peers"` CliVersion string `json:"cliVersion" yaml:"cliVersion"` @@ -100,6 +112,7 @@ type OutputOverview struct { Events []SystemEventOutput `json:"events" yaml:"events"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` ProfileName string `json:"profileName" yaml:"profileName"` + SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"` } func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { @@ -121,6 +134,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status relayOverview := mapRelays(pbFullStatus.GetRelays()) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) + sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState()) overview := OutputOverview{ Peers: peersOverview, @@ -141,6 +155,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status Events: mapEvents(pbFullStatus.GetEvents()), LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(), ProfileName: profName, + SSHServerState: sshServerOverview, } if anon { @@ -190,6 +205,30 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput { return mappedNSGroups } +func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput { + if sshServerState == nil { + return SSHServerStateOutput{ + Enabled: false, + Sessions: []SSHSessionOutput{}, + } + } + + sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions())) + for _, session := range sshServerState.GetSessions() { + sessions = append(sessions, SSHSessionOutput{ + Username: session.GetUsername(), + RemoteAddress: session.GetRemoteAddress(), + Command: session.GetCommand(), + JWTUsername: session.GetJwtUsername(), + }) + } + + return SSHServerStateOutput{ + Enabled: sshServerState.GetEnabled(), + Sessions: sessions, + } +} + func mapPeers( peers []*proto.PeerState, statusFilter string, @@ -300,7 +339,7 @@ func ParseToYAML(overview OutputOverview) (string, error) { return string(yamlBytes), nil } -func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string { +func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string { var managementConnString string if overview.ManagementState.Connected { managementConnString = "Connected" @@ -405,6 +444,41 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, lazyConnectionEnabledStatus = "true" } + sshServerStatus := "Disabled" + if overview.SSHServerState.Enabled { + sessionCount := len(overview.SSHServerState.Sessions) + if sessionCount > 0 { + sessionWord := "session" + if sessionCount > 1 { + sessionWord = "sessions" + } + sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord) + } else { + sshServerStatus = "Enabled" + } + + if showSSHSessions && sessionCount > 0 { + for _, session := range overview.SSHServerState.Sessions { + var sessionDisplay string + if session.JWTUsername != "" { + sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s", + session.JWTUsername, + session.RemoteAddress, + session.Username, + session.Command, + ) + } else { + sessionDisplay = fmt.Sprintf("[%s@%s] %s", + session.Username, + session.RemoteAddress, + session.Command, + ) + } + sshServerStatus += "\n " + sessionDisplay + } + } + } + peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) goos := runtime.GOOS @@ -428,6 +502,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Lazy connection: %s\n"+ + "SSH Server: %s\n"+ "Networks: %s\n"+ "Forwarding rules: %d\n"+ "Peers count: %s\n", @@ -444,6 +519,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, interfaceTypeString, rosenpassEnabledStatus, lazyConnectionEnabledStatus, + sshServerStatus, networks, overview.NumberOfForwardingRules, peersCountString, @@ -454,7 +530,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, func ParseToFullDetailSummary(overview OutputOverview) string { parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) parsedEventsString := parseEvents(overview.Events) - summary := ParseGeneralSummary(overview, true, true, true) + summary := ParseGeneralSummary(overview, true, true, true, true) return fmt.Sprintf( "Peers detail:"+ @@ -746,4 +822,13 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) { event.Metadata[k] = a.AnonymizeString(v) } } + + for i, session := range overview.SSHServerState.Sessions { + if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil { + overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port) + } else { + overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress) + } + overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command) + } } diff --git a/client/status/status_test.go b/client/status/status_test.go index 660efd9ef..1dca1e5b1 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -231,6 +231,10 @@ var overview = OutputOverview{ Networks: []string{ "10.10.0.0/24", }, + SSHServerState: SSHServerStateOutput{ + Enabled: false, + Sessions: []SSHSessionOutput{}, + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -385,7 +389,11 @@ func TestParsingToJSON(t *testing.T) { ], "events": [], "lazyConnectionEnabled": false, - "profileName":"" + "profileName":"", + "sshServer":{ + "enabled":false, + "sessions":[] + } }` // @formatter:on @@ -488,6 +496,9 @@ dnsServers: events: [] lazyConnectionEnabled: false profileName: "" +sshServer: + enabled: false + sessions: [] ` assert.Equal(t, expectedYAML, yaml) @@ -554,6 +565,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Lazy connection: false +SSH Server: Disabled Networks: 10.10.0.0/24 Forwarding rules: 0 Peers count: 2/2 Connected @@ -563,7 +575,7 @@ Peers count: 2/2 Connected } func TestParsingToShortVersion(t *testing.T) { - shortVersion := ParseGeneralSummary(overview, false, false, false) + shortVersion := ParseGeneralSummary(overview, false, false, false, false) expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` Daemon version: 0.14.1 @@ -578,6 +590,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Lazy connection: false +SSH Server: Disabled Networks: 10.10.0.0/24 Forwarding rules: 0 Peers count: 2/2 Connected diff --git a/client/system/info.go b/client/system/info.go index a180be4c0..01176e765 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -72,6 +72,12 @@ type Info struct { BlockInbound bool LazyConnectionEnabled bool + + EnableSSHRoot bool + EnableSSHSFTP bool + EnableSSHLocalPortForwarding bool + EnableSSHRemotePortForwarding bool + DisableSSHAuth bool } func (i *Info) SetFlags( @@ -79,6 +85,8 @@ func (i *Info) SetFlags( serverSSHAllowed *bool, disableClientRoutes, disableServerRoutes, disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool, + enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool, + disableSSHAuth *bool, ) { i.RosenpassEnabled = rosenpassEnabled i.RosenpassPermissive = rosenpassPermissive @@ -94,6 +102,22 @@ func (i *Info) SetFlags( i.BlockInbound = blockInbound i.LazyConnectionEnabled = lazyConnectionEnabled + + if enableSSHRoot != nil { + i.EnableSSHRoot = *enableSSHRoot + } + if enableSSHSFTP != nil { + i.EnableSSHSFTP = *enableSSHSFTP + } + if enableSSHLocalPortForwarding != nil { + i.EnableSSHLocalPortForwarding = *enableSSHLocalPortForwarding + } + if enableSSHRemotePortForwarding != nil { + i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding + } + if disableSSHAuth != nil { + i.DisableSSHAuth = *disableSSHAuth + } } // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d3ab42423..44643616d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -55,6 +55,7 @@ const ( const ( censoredPreSharedKey = "**********" + maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds ) func main() { @@ -265,25 +266,38 @@ type serviceClient struct { iMTU *widget.Entry // switch elements for settings form - sRosenpassPermissive *widget.Check - sNetworkMonitor *widget.Check - sDisableDNS *widget.Check - sDisableClientRoutes *widget.Check - sDisableServerRoutes *widget.Check - sBlockLANAccess *widget.Check + sRosenpassPermissive *widget.Check + sNetworkMonitor *widget.Check + sDisableDNS *widget.Check + sDisableClientRoutes *widget.Check + sDisableServerRoutes *widget.Check + sBlockLANAccess *widget.Check + sEnableSSHRoot *widget.Check + sEnableSSHSFTP *widget.Check + sEnableSSHLocalPortForward *widget.Check + sEnableSSHRemotePortForward *widget.Check + sDisableSSHAuth *widget.Check + iSSHJWTCacheTTL *widget.Entry // observable settings over corresponding iMngURL and iPreSharedKey values. - managementURL string - preSharedKey string - RosenpassPermissive bool - interfaceName string - interfacePort int - mtu uint16 - networkMonitor bool - disableDNS bool - disableClientRoutes bool - disableServerRoutes bool - blockLANAccess bool + managementURL string + preSharedKey string + + RosenpassPermissive bool + interfaceName string + interfacePort int + mtu uint16 + networkMonitor bool + disableDNS bool + disableClientRoutes bool + disableServerRoutes bool + blockLANAccess bool + enableSSHRoot bool + enableSSHSFTP bool + enableSSHLocalPortForward bool + enableSSHRemotePortForward bool + disableSSHAuth bool + sshJWTCacheTTL int connected bool update *version.Update @@ -435,18 +449,22 @@ func (s *serviceClient) showSettingsUI() { s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil) s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil) s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil) + s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil) + s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil) + s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil) + s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil) + s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil) + s.iSSHJWTCacheTTL = widget.NewEntry() s.wSettings.SetContent(s.getSettingsForm()) - s.wSettings.Resize(fyne.NewSize(600, 500)) + s.wSettings.Resize(fyne.NewSize(600, 400)) s.wSettings.SetFixedSize(true) s.getSrvConfig() s.wSettings.Show() } -// getSettingsForm to embed it into settings window. -func (s *serviceClient) getSettingsForm() *widget.Form { - +func (s *serviceClient) getConnectionForm() *widget.Form { var activeProfName string activeProf, err := s.profileManager.GetActiveProfile() if err != nil { @@ -457,153 +475,277 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return &widget.Form{ Items: []*widget.FormItem{ {Text: "Profile", Widget: widget.NewLabel(activeProfName)}, + {Text: "Management URL", Widget: s.iMngURL}, + {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive}, {Text: "Interface Name", Widget: s.iInterfaceName}, {Text: "Interface Port", Widget: s.iInterfacePort}, {Text: "MTU", Widget: s.iMTU}, - {Text: "Management URL", Widget: s.iMngURL}, - {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Log File", Widget: s.iLogFile}, + }, + } +} + +func (s *serviceClient) saveSettings() { + // Check if update settings are disabled by daemon + features, err := s.getFeatures() + if err != nil { + log.Errorf("failed to get features from daemon: %v", err) + // Continue with default behavior if features can't be retrieved + } else if features != nil && features.DisableUpdateSettings { + log.Warn("Configuration updates are disabled by daemon") + dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) + return + } + + if err := s.validateSettings(); err != nil { + dialog.ShowError(err, s.wSettings) + return + } + + port, mtu, err := s.parseNumericSettings() + if err != nil { + dialog.ShowError(err, s.wSettings) + return + } + + iMngURL := strings.TrimSpace(s.iMngURL.Text) + + if s.hasSettingsChanged(iMngURL, port, mtu) { + if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil { + dialog.ShowError(err, s.wSettings) + return + } + } + + s.wSettings.Close() +} + +func (s *serviceClient) validateSettings() error { + if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { + if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { + return fmt.Errorf("Invalid Pre-shared Key Value") + } + } + return nil +} + +func (s *serviceClient) parseNumericSettings() (int64, int64, error) { + port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) + if err != nil { + return 0, 0, errors.New("Invalid interface port") + } + if port < 1 || port > 65535 { + return 0, 0, errors.New("Invalid interface port: out of range 1-65535") + } + + var mtu int64 + mtuText := strings.TrimSpace(s.iMTU.Text) + if mtuText != "" { + mtu, err = strconv.ParseInt(mtuText, 10, 64) + if err != nil { + return 0, 0, errors.New("Invalid MTU value") + } + if mtu < iface.MinMTU || mtu > iface.MaxMTU { + return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU) + } + } + + return port, mtu, nil +} + +func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool { + return s.managementURL != iMngURL || + s.preSharedKey != s.iPreSharedKey.Text || + s.RosenpassPermissive != s.sRosenpassPermissive.Checked || + s.interfaceName != s.iInterfaceName.Text || + s.interfacePort != int(port) || + s.mtu != uint16(mtu) || + s.networkMonitor != s.sNetworkMonitor.Checked || + s.disableDNS != s.sDisableDNS.Checked || + s.disableClientRoutes != s.sDisableClientRoutes.Checked || + s.disableServerRoutes != s.sDisableServerRoutes.Checked || + s.blockLANAccess != s.sBlockLANAccess.Checked || + s.hasSSHChanges() +} + +func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error { + s.managementURL = iMngURL + s.preSharedKey = s.iPreSharedKey.Text + s.mtu = uint16(mtu) + + req, err := s.buildSetConfigRequest(iMngURL, port, mtu) + if err != nil { + return fmt.Errorf("build config request: %w", err) + } + + if err := s.sendConfigUpdate(req); err != nil { + return fmt.Errorf("set configuration: %w", err) + } + + return nil +} + +func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) { + currUser, err := user.Current() + if err != nil { + return nil, fmt.Errorf("get current user: %w", err) + } + + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + return nil, fmt.Errorf("get active profile: %w", err) + } + + req := &proto.SetConfigRequest{ + ProfileName: activeProf.Name, + Username: currUser.Username, + } + + if iMngURL != "" { + req.ManagementUrl = iMngURL + } + + req.RosenpassPermissive = &s.sRosenpassPermissive.Checked + req.InterfaceName = &s.iInterfaceName.Text + req.WireguardPort = &port + if mtu > 0 { + req.Mtu = &mtu + } + + req.NetworkMonitor = &s.sNetworkMonitor.Checked + req.DisableDns = &s.sDisableDNS.Checked + req.DisableClientRoutes = &s.sDisableClientRoutes.Checked + req.DisableServerRoutes = &s.sDisableServerRoutes.Checked + req.BlockLanAccess = &s.sBlockLANAccess.Checked + + req.EnableSSHRoot = &s.sEnableSSHRoot.Checked + req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked + req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked + req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked + req.DisableSSHAuth = &s.sDisableSSHAuth.Checked + + sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text) + if sshJWTCacheTTLText != "" { + sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32) + if err != nil { + return nil, errors.New("Invalid SSH JWT Cache TTL value") + } + if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL { + return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL) + } + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + req.SshJWTCacheTTL = &sshJWTCacheTTL32 + } + + if s.iPreSharedKey.Text != censoredPreSharedKey { + req.OptionalPreSharedKey = &s.iPreSharedKey.Text + } + + return req, nil +} + +func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error { + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return fmt.Errorf("get client: %w", err) + } + + _, err = conn.SetConfig(s.ctx, req) + if err != nil { + return fmt.Errorf("set config: %w", err) + } + + // Reconnect if connected to apply the new settings + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + log.Errorf("get service status: %v", err) + return + } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + return + } + } + }() + + return nil +} + +func (s *serviceClient) getSettingsForm() fyne.CanvasObject { + connectionForm := s.getConnectionForm() + networkForm := s.getNetworkForm() + sshForm := s.getSSHForm() + tabs := container.NewAppTabs( + container.NewTabItem("Connection", connectionForm), + container.NewTabItem("Network", networkForm), + container.NewTabItem("SSH", sshForm), + ) + saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings) + saveButton.Importance = widget.HighImportance + cancelButton := widget.NewButtonWithIcon("Cancel", theme.CancelIcon(), func() { + s.wSettings.Close() + }) + buttonContainer := container.NewHBox( + layout.NewSpacer(), + cancelButton, + saveButton, + ) + return container.NewBorder(nil, buttonContainer, nil, nil, tabs) +} + +func (s *serviceClient) getNetworkForm() *widget.Form { + return &widget.Form{ + Items: []*widget.FormItem{ {Text: "Network Monitor", Widget: s.sNetworkMonitor}, {Text: "Disable DNS", Widget: s.sDisableDNS}, {Text: "Disable Client Routes", Widget: s.sDisableClientRoutes}, {Text: "Disable Server Routes", Widget: s.sDisableServerRoutes}, {Text: "Disable LAN Access", Widget: s.sBlockLANAccess}, }, - SubmitText: "Save", - OnSubmit: func() { - // Check if update settings are disabled by daemon - features, err := s.getFeatures() - if err != nil { - log.Errorf("failed to get features from daemon: %v", err) - // Continue with default behavior if features can't be retrieved - } else if features != nil && features.DisableUpdateSettings { - log.Warn("Configuration updates are disabled by daemon") - dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) - return - } + } +} - if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { - // validate preSharedKey if it added - if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { - dialog.ShowError(fmt.Errorf("Invalid Pre-shared Key Value"), s.wSettings) - return - } - } - - port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid interface port"), s.wSettings) - return - } - - var mtu int64 - mtuText := strings.TrimSpace(s.iMTU.Text) - if mtuText != "" { - var err error - mtu, err = strconv.ParseInt(mtuText, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings) - return - } - if mtu < iface.MinMTU || mtu > iface.MaxMTU { - dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings) - return - } - } - - iMngURL := strings.TrimSpace(s.iMngURL.Text) - - defer s.wSettings.Close() - - // Check if any settings have changed - if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || - s.RosenpassPermissive != s.sRosenpassPermissive.Checked || - s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) || - s.mtu != uint16(mtu) || - s.networkMonitor != s.sNetworkMonitor.Checked || - s.disableDNS != s.sDisableDNS.Checked || - s.disableClientRoutes != s.sDisableClientRoutes.Checked || - s.disableServerRoutes != s.sDisableServerRoutes.Checked || - s.blockLANAccess != s.sBlockLANAccess.Checked { - - s.managementURL = iMngURL - s.preSharedKey = s.iPreSharedKey.Text - s.mtu = uint16(mtu) - - currUser, err := user.Current() - if err != nil { - log.Errorf("get current user: %v", err) - return - } - - var req proto.SetConfigRequest - req.ProfileName = activeProf.Name - req.Username = currUser.Username - - if iMngURL != "" { - req.ManagementUrl = iMngURL - } - - req.RosenpassPermissive = &s.sRosenpassPermissive.Checked - req.InterfaceName = &s.iInterfaceName.Text - req.WireguardPort = &port - if mtu > 0 { - req.Mtu = &mtu - } - req.NetworkMonitor = &s.sNetworkMonitor.Checked - req.DisableDns = &s.sDisableDNS.Checked - req.DisableClientRoutes = &s.sDisableClientRoutes.Checked - req.DisableServerRoutes = &s.sDisableServerRoutes.Checked - req.BlockLanAccess = &s.sBlockLANAccess.Checked - - if s.iPreSharedKey.Text != censoredPreSharedKey { - req.OptionalPreSharedKey = &s.iPreSharedKey.Text - } - - conn, err := s.getSrvClient(failFastTimeout) - if err != nil { - log.Errorf("get client: %v", err) - dialog.ShowError(fmt.Errorf("Failed to connect to the service: %v", err), s.wSettings) - return - } - _, err = conn.SetConfig(s.ctx, &req) - if err != nil { - log.Errorf("set config: %v", err) - dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings) - return - } - - go func() { - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) - if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) - return - } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) - } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) - return - } - } - }() - } - }, - OnCancel: func() { - s.wSettings.Close() +func (s *serviceClient) getSSHForm() *widget.Form { + return &widget.Form{ + Items: []*widget.FormItem{ + {Text: "Enable SSH Root Login", Widget: s.sEnableSSHRoot}, + {Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP}, + {Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward}, + {Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward}, + {Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth}, + {Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL}, }, } } +func (s *serviceClient) hasSSHChanges() bool { + currentSSHJWTCacheTTL := s.sshJWTCacheTTL + if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" { + val, err := strconv.Atoi(text) + if err != nil { + return true + } + currentSSHJWTCacheTTL = val + } + + return s.enableSSHRoot != s.sEnableSSHRoot.Checked || + s.enableSSHSFTP != s.sEnableSSHSFTP.Checked || + s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked || + s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked || + s.disableSSHAuth != s.sDisableSSHAuth.Checked || + s.sshJWTCacheTTL != currentSSHJWTCacheTTL +} + func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { @@ -1123,6 +1265,25 @@ func (s *serviceClient) getSrvConfig() { s.disableServerRoutes = cfg.DisableServerRoutes s.blockLANAccess = cfg.BlockLANAccess + if cfg.EnableSSHRoot != nil { + s.enableSSHRoot = *cfg.EnableSSHRoot + } + if cfg.EnableSSHSFTP != nil { + s.enableSSHSFTP = *cfg.EnableSSHSFTP + } + if cfg.EnableSSHLocalPortForwarding != nil { + s.enableSSHLocalPortForward = *cfg.EnableSSHLocalPortForwarding + } + if cfg.EnableSSHRemotePortForwarding != nil { + s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding + } + if cfg.DisableSSHAuth != nil { + s.disableSSHAuth = *cfg.DisableSSHAuth + } + if cfg.SSHJWTCacheTTL != nil { + s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL + } + if s.showAdvancedSettings { s.iMngURL.SetText(s.managementURL) s.iPreSharedKey.SetText(cfg.PreSharedKey) @@ -1143,6 +1304,24 @@ func (s *serviceClient) getSrvConfig() { s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes) s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes) s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess) + if cfg.EnableSSHRoot != nil { + s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot) + } + if cfg.EnableSSHSFTP != nil { + s.sEnableSSHSFTP.SetChecked(*cfg.EnableSSHSFTP) + } + if cfg.EnableSSHLocalPortForwarding != nil { + s.sEnableSSHLocalPortForward.SetChecked(*cfg.EnableSSHLocalPortForwarding) + } + if cfg.EnableSSHRemotePortForwarding != nil { + s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding) + } + if cfg.DisableSSHAuth != nil { + s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth) + } + if cfg.SSHJWTCacheTTL != nil { + s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL)) + } } if s.mNotifications == nil { @@ -1213,6 +1392,15 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.DisableServerRoutes = cfg.DisableServerRoutes config.BlockLANAccess = cfg.BlockLanAccess + config.EnableSSHRoot = &cfg.EnableSSHRoot + config.EnableSSHSFTP = &cfg.EnableSSHSFTP + config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding + config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding + config.DisableSSHAuth = &cfg.DisableSSHAuth + + ttl := int(cfg.SshJWTCacheTTL) + config.SSHJWTCacheTTL = &ttl + return &config } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d542e2739..4dc14a1ca 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" netbird "github.com/netbirdio/netbird/client/embed" + sshdetection "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" @@ -125,10 +126,15 @@ func createSSHMethod(client *netbird.Client) js.Func { username = args[2].String() } + var jwtToken string + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + return createPromise(func(resolve, reject js.Value) { sshClient := ssh.NewClient(client) - if err := sshClient.Connect(host, port, username); err != nil { + if err := sshClient.Connect(host, port, username, jwtToken); err != nil { reject.Invoke(err.Error()) return } @@ -191,12 +197,43 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value { })) } +// createDetectSSHServerMethod creates the SSH server detection method +func createDetectSSHServerMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: requires host and port") + } + + host := args[0].String() + port := args[1].Int() + + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + serverType, err := detectSSHServerType(ctx, client, host, port) + if err != nil { + reject.Invoke(err.Error()) + return + } + + resolve.Invoke(js.ValueOf(serverType.RequiresJWT())) + }) + }) +} + +// detectSSHServerType detects SSH server type using NetBird network connection +func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) { + return sshdetection.DetectSSHServerType(ctx, client, host, port) +} + // createClientObject wraps the NetBird client in a JavaScript object func createClientObject(client *netbird.Client) js.Value { obj := make(map[string]interface{}) obj["start"] = createStartMethod(client) obj["stop"] = createStopMethod(client) + obj["detectSSHServerType"] = createDetectSSHServerMethod(client) obj["createSSHConnection"] = createSSHMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client) diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index ca35525eb..568437e56 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/ssh" netbird "github.com/netbirdio/netbird/client/embed" + nbssh "github.com/netbirdio/netbird/client/ssh" ) const ( @@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client { } // Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username string) error { +func (c *Client) Connect(host string, port int, username, jwtToken string) error { addr := fmt.Sprintf("%s:%d", host, port) logrus.Infof("SSH: Connecting to %s as %s", addr, username) - var authMethods []ssh.AuthMethod - - nbConfig, err := c.nbClient.GetConfig() + authMethods, err := c.getAuthMethods(jwtToken) if err != nil { - return fmt.Errorf("get NetBird config: %w", err) + return err } - if nbConfig.SSHKey == "" { - return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization") - } - - signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey)) - if err != nil { - return fmt.Errorf("parse NetBird SSH private key: %w", err) - } - - pubKey := signer.PublicKey() - logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type()) - - authMethods = append(authMethods, ssh.PublicKeys(signer)) config := &ssh.ClientConfig{ User: username, Auth: authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient), Timeout: sshDialTimeout, } @@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error { return nil } +// getAuthMethods returns SSH authentication methods, preferring JWT if available +func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) { + if jwtToken != "" { + logrus.Debugf("SSH: Using JWT password authentication") + return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil + } + + logrus.Debugf("SSH: No JWT token, using public key authentication") + + nbConfig, err := c.nbClient.GetConfig() + if err != nil { + return nil, fmt.Errorf("get NetBird config: %w", err) + } + + if nbConfig.SSHKey == "" { + return nil, fmt.Errorf("no NetBird SSH key available") + } + + signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey)) + if err != nil { + return nil, fmt.Errorf("parse NetBird SSH private key: %w", err) + } + + logrus.Debugf("SSH: Added public key auth") + return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil +} + // StartSession starts an SSH session with PTY func (c *Client) StartSession(cols, rows int) error { if c.sshClient == nil { diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go deleted file mode 100644 index 4868ba30a..000000000 --- a/client/wasm/internal/ssh/key.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build js - -package ssh - -import ( - "crypto/x509" - "encoding/pem" - "fmt" - "strings" - - "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh" -) - -// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format -func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) { - keyStr := string(keyPEM) - if !strings.Contains(keyStr, "-----BEGIN") { - keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----") - } - - signer, err := ssh.ParsePrivateKey(keyPEM) - if err == nil { - return signer, nil - } - logrus.Debugf("SSH: Failed to parse as SSH format: %v", err) - - block, _ := pem.Decode(keyPEM) - if block == nil { - keyPreview := string(keyPEM) - if len(keyPreview) > 100 { - keyPreview = keyPreview[:100] - } - return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview) - } - - key, err := x509.ParsePKCS8PrivateKey(block.Bytes) - if err != nil { - logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err) - if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { - return ssh.NewSignerFromKey(rsaKey) - } - if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil { - return ssh.NewSignerFromKey(ecKey) - } - return nil, fmt.Errorf("parse private key: %w", err) - } - - return ssh.NewSignerFromKey(key) -} diff --git a/go.mod b/go.mod index 2d7e0d31c..45a36190d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/netbirdio/netbird -go 1.23.0 +go 1.23.1 require ( cunicu.li/go-rosenpass v0.4.0 @@ -17,8 +17,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.40.0 - golang.org/x/sys v0.34.0 + golang.org/x/crypto v0.41.0 + golang.org/x/sys v0.35.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -31,6 +31,7 @@ require ( fyne.io/fyne/v2 v2.7.0 fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible + github.com/awnumar/memguard v0.23.0 github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/config v1.29.14 github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 @@ -76,6 +77,7 @@ require ( github.com/pion/stun/v3 v3.0.0 github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 + github.com/pkg/sftp v1.13.9 github.com/prometheus/client_golang v1.22.0 github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 @@ -108,7 +110,7 @@ require ( golang.org/x/net v0.42.0 golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 - golang.org/x/term v0.33.0 + golang.org/x/term v0.34.0 golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 @@ -130,6 +132,7 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/awnumar/memcall v0.4.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect @@ -197,6 +200,7 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/kr/fs v0.1.0 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -248,8 +252,8 @@ require ( go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.24.0 // indirect - golang.org/x/text v0.27.0 // indirect - golang.org/x/tools v0.34.0 // indirect + golang.org/x/text v0.28.0 // indirect + golang.org/x/tools v0.35.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect diff --git a/go.sum b/go.sum index f4b62dff0..ec68a8f59 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,10 @@ github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJT github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= +github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w= +github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A= +github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= @@ -303,6 +307,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -432,6 +438,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= +github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw= +github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= @@ -587,9 +595,13 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= @@ -607,6 +619,9 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -628,7 +643,10 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -643,6 +661,10 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -675,19 +697,28 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= -golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -695,9 +726,12 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -713,8 +747,10 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= -golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 409bdaaba..18a8427be 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -66,7 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 9a4681eae..7f64034df 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -3,6 +3,8 @@ package grpc import ( "context" "fmt" + "net/url" + "strings" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" nbdns "github.com/netbirdio/netbird/dns" @@ -81,12 +83,21 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken return nbConfig } -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, config *nbconfig.Config) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) + + sshConfig := &proto.SSHConfig{ + SshEnabled: peer.SSHEnabled, + } + + if peer.SSHEnabled { + sshConfig.JwtConfig = buildJWTConfig(config) + } + return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), + SshConfig: sshConfig, Fqdn: fqdn, RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled, @@ -95,7 +106,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, config), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -350,3 +361,51 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } return protoGroup } + +// buildJWTConfig constructs JWT configuration for SSH servers from management server config +func buildJWTConfig(config *nbconfig.Config) *proto.JWTConfig { + if config == nil { + return nil + } + + if config.HttpConfig == nil || config.HttpConfig.AuthAudience == "" { + return nil + } + + issuer := strings.TrimSpace(config.HttpConfig.AuthIssuer) + if issuer == "" { + if config.DeviceAuthorizationFlow != nil { + if d := deriveIssuerFromTokenEndpoint(config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint); d != "" { + issuer = d + } + } + } + if issuer == "" { + return nil + } + + keysLocation := strings.TrimSpace(config.HttpConfig.AuthKeysLocation) + if keysLocation == "" { + keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json" + } + + return &proto.JWTConfig{ + Issuer: issuer, + Audience: config.HttpConfig.AuthAudience, + KeysLocation: keysLocation, + } +} + +// deriveIssuerFromTokenEndpoint extracts the issuer URL from a token endpoint +func deriveIssuerFromTokenEndpoint(tokenEndpoint string) string { + if tokenEndpoint == "" { + return "" + } + + u, err := url.Parse(tokenEndpoint) + if err != nil { + return "" + } + + return fmt.Sprintf("%s://%s/", u.Scheme, u.Host) +} diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 08a840316..4364272a0 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -646,7 +646,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config), Checks: toProtocolChecks(ctx, postureChecks), } diff --git a/management/server/account.go b/management/server/account.go index a4b2a752b..3e498536c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -15,6 +15,8 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/shared/auth" + cacheStore "github.com/eko/gocache/lib/v4/store" "github.com/eko/gocache/store/redis/v4" "github.com/rs/xid" @@ -25,6 +27,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -81,6 +84,9 @@ type DefaultAccountManager struct { proxyController port_forwarding.Controller settingsManager settings.Manager + // config contains the management server configuration + config *nbconfig.Config + // singleAccountMode indicates whether the instance has a single account. // If true, then every new user will end up under the same account. // This value will be set to false if management service has more than one account. @@ -171,6 +177,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, + config *nbconfig.Config, store store.Store, networkMapController network_map.Controller, idpManager idp.Manager, @@ -192,6 +199,7 @@ func BuildManager( am := &DefaultAccountManager{ Store: store, + config: config, geo: geo, networkMapController: networkMapController, idpManager: idpManager, @@ -1006,7 +1014,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun } // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes -func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth, +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth, primaryDomain bool, ) error { if userAuth.Domain == "" { @@ -1055,7 +1063,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, userAccountID string, domainAccountID string, - userAuth nbcontext.UserAuth, + userAuth auth.UserAuth, ) error { primaryDomain := domainAccountID == "" || userAccountID == domainAccountID err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain) @@ -1074,7 +1082,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { if userAuth.UserId == "" { return "", fmt.Errorf("user ID is empty") } @@ -1105,7 +1113,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai return newAccount.Id, nil } -func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID @@ -1217,7 +1225,7 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { - log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err) + log.Errorf("failed to get account onboarding for account %s: %v", accountID, err) return nil, err } @@ -1269,7 +1277,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac return newOnboarding, nil } -func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) } @@ -1313,7 +1321,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager -func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { if userAuth.IsChild || userAuth.IsPAT { return nil } @@ -1471,7 +1479,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // // UserAuth IsChild -> checks that account exists -func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory) @@ -1550,7 +1558,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont return domainAccountID, cancel, nil } -func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) { userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) @@ -1598,7 +1606,7 @@ func handleNotFound(err error) error { return nil } -func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool { +func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool { return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 7c174a481..9b3902d87 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -6,10 +6,11 @@ import ( "net/netip" "time" + "github.com/netbirdio/netbird/shared/auth" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -45,10 +46,10 @@ type Manager interface { GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) - GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) - GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error @@ -117,11 +118,11 @@ type Manager interface { UpdateAccountPeers(ctx context.Context, accountID string) BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) - SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error + SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error GetStore() store.Store GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) - GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) } diff --git a/management/server/account_test.go b/management/server/account_test.go index ee9950796..10d718bbf 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -28,7 +28,6 @@ import ( nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -45,6 +44,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" ) func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) { @@ -445,7 +445,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { - type initUserParams nbcontext.UserAuth + type initUserParams auth.UserAuth var ( publicDomain = "public.com" @@ -468,7 +468,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCases := []struct { name string - inputClaims nbcontext.UserAuth + inputClaims auth.UserAuth inputInitUserParams initUserParams inputUpdateAttrs bool inputUpdateClaimAccount bool @@ -483,7 +483,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }{ { name: "New User With Public Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: publicDomain, UserId: "pub-domain-user", DomainCategory: types.PublicCategory, @@ -500,7 +500,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Unknown Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: unknownDomain, UserId: "unknown-domain-user", DomainCategory: types.UnknownCategory, @@ -517,7 +517,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -534,7 +534,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New Regular User With Existing Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "new-pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -552,7 +552,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing User With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -569,7 +569,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -587,7 +587,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "User With Private Category And Empty Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: "", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -616,7 +616,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -656,7 +656,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount Domain: domain, UserId: userId, @@ -915,13 +915,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) { } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ Domain: "example.com", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, } - publicClaims := nbcontext.UserAuth{ + publicClaims := auth.UserAuth{ Domain: "test.com", UserId: "public-domain-user", DomainCategory: types.PublicCategory, @@ -2709,7 +2709,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") t.Run("skip sync for token auth type", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group3"}, @@ -2724,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("empty jwt groups", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2738,7 +2738,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("jwt match existing api group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2759,7 +2759,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"])) - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2777,7 +2777,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add jwt group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1", "group2"}, @@ -2791,7 +2791,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("existed group not update", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group2"}, @@ -2805,7 +2805,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add new group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{"group1", "group3"}, @@ -2823,7 +2823,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when list is empty", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2838,7 +2838,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{}, @@ -2959,7 +2959,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err } @@ -3692,7 +3692,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { // Test adding new user to existing account with approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3722,7 +3722,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { } // Create a domain-based account without user approval - ownerUserAuth := nbcontext.UserAuth{ + ownerUserAuth := auth.UserAuth{ UserId: "owner-user", Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3741,7 +3741,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { // Test adding new user to existing account without approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index ece9dc321..0c62357dc 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -9,18 +9,19 @@ import ( "github.com/golang-jwt/jwt/v5" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/base62" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) var _ Manager = (*manager)(nil) type Manager interface { - ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsed(ctx context.Context, tokenID string) error GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } @@ -55,20 +56,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s } } -func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { token, err := m.validator.ValidateAndParse(ctx, value) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } userAuth, err := m.extractor.ToUserAuth(token) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } return userAuth, token, err } -func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go index 30a7a7161..edf158a49 100644 --- a/management/server/auth/manager_mock.go +++ b/management/server/auth/manager_mock.go @@ -3,9 +3,10 @@ package auth import ( "context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/golang-jwt/jwt/v5" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" ) @@ -15,18 +16,18 @@ var ( // @note really dislike this mocking approach but rather than have to do additional test refactoring. type MockManager struct { - ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsedFunc func(ctx context.Context, tokenID string) error GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } // EnsureUserAccessByJWTGroups implements Manager. -func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if m.EnsureUserAccessByJWTGroupsFunc != nil { return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token) } - return nbcontext.UserAuth{}, nil + return auth.UserAuth{}, nil } // GetPATInfo implements Manager. @@ -46,9 +47,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error { } // ValidateAndParseToken implements Manager. -func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { if m.ValidateAndParseTokenFunc != nil { return m.ValidateAndParseTokenFunc(ctx, value) } - return nbcontext.UserAuth{}, &jwt.Token{}, nil + return auth.UserAuth{}, &jwt.Token{}, nil } diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index c8015eb37..b9f091b1e 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -17,10 +17,10 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { @@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { } // this has been validated and parsed by ValidateAndParseToken - userAuth := nbcontext.UserAuth{ + userAuth := nbauth.UserAuth{ AccountId: account.Id, Domain: domain, UserId: userId, @@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tests := []struct { name string tokenFunc func() string - expected *nbcontext.UserAuth // nil indicates expected error + expected *nbauth.UserAuth // nil indicates expected error }{ { name: "Valid with custom claims", @@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", AccountId: "account-id|567", Domain: "http://localhost", @@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", }, }, diff --git a/management/server/context/auth.go b/management/server/context/auth.go index 5cb28ddb7..cc59b8a63 100644 --- a/management/server/context/auth.go +++ b/management/server/context/auth.go @@ -4,7 +4,8 @@ import ( "context" "fmt" "net/http" - "time" + + "github.com/netbirdio/netbird/shared/auth" ) type key int @@ -13,45 +14,22 @@ const ( UserAuthContextKey key = iota ) -type UserAuth struct { - // The account id the user is accessing - AccountId string - // The account domain - Domain string - // The account domain category, TBC values - DomainCategory string - // Indicates whether this user was invited, TBC logic - Invited bool - // Indicates whether this is a child account - IsChild bool - - // The user id - UserId string - // Last login time for this user - LastLogin time.Time - // The Groups the user belongs to on this account - Groups []string - - // Indicates whether this user has authenticated with a Personal Access Token - IsPAT bool -} - -func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) { +func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) { return GetUserAuthFromContext(r.Context()) } -func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request { +func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request { return r.WithContext(SetUserAuthInContext(r.Context(), userAuth)) } -func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) { - if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok { +func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) { + if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok { return userAuth, nil } - return UserAuth{}, fmt.Errorf("user auth not in context") + return auth.UserAuth{}, fmt.Errorf("user auth not in context") } -func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context { +func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context { //nolint ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId) //nolint diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 356a2f640..6b7a36c20 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -224,7 +224,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock()) - return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createDNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 4b9b79fdc..c5c48ef32 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, AccountId: accountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 08a0b2afd..67638aea5 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index 42b519c29..a027c067e 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, AccountId: testingDNSSettingsAccount.Id, Domain: testingDNSSettingsAccount.Domain, diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index d49b6c7e0..4716782f3 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", AccountId: testNSGroupAccountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index a0695fa3f..923a24e31 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -14,11 +14,12 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) func initEventsTestData(account string, events ...*activity.Event) *handler { @@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_account", diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index e861e873c..208a2e828 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -11,10 +11,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns groups of the account diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 34694ec8c..b7dd3944a 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -19,12 +19,13 @@ import ( "github.com/netbirdio/netbird/management/server" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index d7b598a5d..f99eca794 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -12,15 +12,15 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/shared/management/status" nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) // handler is a handler that returns networks of the account diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index 59396dceb..c31729a39 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -8,10 +8,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type resourceHandler struct { diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index 2e64c637f..c311a29fe 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -7,10 +7,10 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type routersHandler struct { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 7a5a6d911..ddf2e2a70 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -21,6 +21,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/stretchr/testify/assert" @@ -296,7 +297,7 @@ func TestGetPeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "admin_user", Domain: "hotmail.com", AccountId: "test_id", @@ -444,7 +445,7 @@ func TestGetAccessiblePeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", @@ -527,7 +528,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody))) req.Header.Set("Content-Type", "application/json") - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index cedd5ac88..094a36e38 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -16,12 +16,13 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/util" ) @@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index cb6995793..a2d656a47 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -9,11 +9,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index 4d6bad5e3..ab1639ab1 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index fd39ae2a3..ca5a0a6ab 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -14,10 +14,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) func initPoliciesTestData(policies ...*types.Policy) *handler { @@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 3ebc4d1e1..744cde10b 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index c644b533a..8c60d6fe8 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -16,9 +16,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 466a7987f..a44d81e3e 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" @@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: testAccountID, diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 2287dadfe..d267b6eea 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 7b46b486b..b137b6dd1 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -15,10 +15,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, Domain: "hotmail.com", AccountId: "testAccountId", diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index bae07af4a..867db3ca9 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -8,10 +8,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 92544c56d..7cda14468 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -17,10 +17,11 @@ import ( "github.com/netbirdio/netbird/management/server/util" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index e08004218..37f0a6c1d 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -128,7 +129,7 @@ func initUsersTestData() *handler { return nil }, - GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { switch userAuth.UserId { case "not-found": return nil, status.NewUserNotFoundError("not-found") @@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) { tt := []struct { name string expectedStatus int - requestAuth nbcontext.UserAuth + requestAuth auth.UserAuth expectedResult *api.User }{ { @@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) { }, { name: "user not found", - requestAuth: nbcontext.UserAuth{UserId: "not-found"}, + requestAuth: auth.UserAuth{UserId: "not-found"}, expectedStatus: http.StatusNotFound, }, { name: "not of account", - requestAuth: nbcontext.UserAuth{UserId: "not-of-account"}, + requestAuth: auth.UserAuth{UserId: "not-of-account"}, expectedStatus: http.StatusForbidden, }, { name: "blocked user", - requestAuth: nbcontext.UserAuth{UserId: "blocked-user"}, + requestAuth: auth.UserAuth{UserId: "blocked-user"}, expectedStatus: http.StatusForbidden, }, { name: "service user", - requestAuth: nbcontext.UserAuth{UserId: "service-user"}, + requestAuth: auth.UserAuth{UserId: "service-user"}, expectedStatus: http.StatusForbidden, }, { name: "owner", - requestAuth: nbcontext.UserAuth{UserId: "owner"}, + requestAuth: auth.UserAuth{UserId: "owner"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "owner", @@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "regular user", - requestAuth: nbcontext.UserAuth{UserId: "regular-user"}, + requestAuth: auth.UserAuth{UserId: "regular-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "regular-user", @@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "admin user", - requestAuth: nbcontext.UserAuth{UserId: "admin-user"}, + requestAuth: auth.UserAuth{UserId: "admin-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "admin-user", @@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "restricted user", - requestAuth: nbcontext.UserAuth{UserId: "restricted-user"}, + requestAuth: auth.UserAuth{UserId: "restricted-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "restricted-user", @@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) { req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } @@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) { req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index bce917a25..9439165a4 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -10,22 +10,23 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/auth" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) -type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) -type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error +type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) +type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error -type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) +type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - authManager auth.Manager + authManager serverauth.Manager ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc @@ -34,7 +35,7 @@ type AuthMiddleware struct { // NewAuthMiddleware instance constructor func NewAuthMiddleware( - authManager auth.Manager, + authManager serverauth.Manager, ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, @@ -61,18 +62,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return } - auth := strings.Split(r.Header.Get("Authorization"), " ") - authType := strings.ToLower(auth[0]) + authHeader := strings.Split(r.Header.Get("Authorization"), " ") + authType := strings.ToLower(authHeader[0]) // fallback to token when receive pat as bearer - if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") { + if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") { authType = "token" - auth[0] = authType + authHeader[0] = authType } switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, auth) + request, err := m.checkJWTFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) @@ -81,7 +82,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { h.ServeHTTP(w, request) case "token": - request, err := m.checkPATFromRequest(r, auth) + request, err := m.checkPATFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) // Check if it's a status error, otherwise default to Unauthorized @@ -100,8 +101,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromJWTRequest(auth) +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { @@ -151,8 +152,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromPATRequest(auth) +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { return r, fmt.Errorf("error extracting token: %w", err) } @@ -177,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, err } - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: user.Id, AccountId: user.AccountID, Domain: accDomain, diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d1bd9959f..7badc03e4 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -12,11 +12,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) const ( @@ -75,9 +76,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use return nil, nil, "", "", fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) { if token == JWT { - return nbcontext.UserAuth{ + return nbauth.UserAuth{ UserId: userID, AccountId: accountID, Domain: testAccount.Domain, @@ -91,7 +92,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA Valid: true, }, nil } - return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid") + return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid") } func mockMarkPATUsed(_ context.Context, token string) error { @@ -101,7 +102,7 @@ func mockMarkPATUsed(_ context.Context, token string) error { return fmt.Errorf("Should never get reached") } -func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } @@ -197,13 +198,13 @@ func TestAuthMiddleware_Handler(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, nil, @@ -255,13 +256,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -306,13 +307,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -348,13 +349,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -391,13 +392,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -454,13 +455,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -508,13 +509,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name string path string authHeader string - expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status + expectedUserAuth *nbauth.UserAuth // nil expects 401 response status }{ { name: "Valid PAT Token", path: "/test", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -526,7 +527,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid PAT Token accesses child", path: "/test?account=xyz", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -539,7 +540,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token", path: "/test", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -551,7 +552,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token with child", path: "/test?account=xyz", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -570,13 +571,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, nil, diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index ab3f5437a..ac165aeb2 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" @@ -18,8 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" + serverauth "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" ) func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { @@ -71,14 +72,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee ctx := context.Background() requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock()) - am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) } // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := auth.NewManager(store, "", "", "", "", []string{}, false) - authManagerMock := &auth.MockManager{ + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, MarkPATUsedFunc: authManager.MarkPATUsed, @@ -123,8 +124,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_m } } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} +func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) { + userAuth := auth.UserAuth{} switch token { case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go index 49075a0d3..126a76919 100644 --- a/management/server/idp/pocketid_test.go +++ b/management/server/idp/pocketid_test.go @@ -1,138 +1,137 @@ package idp import ( - "context" - "testing" + "context" + "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/telemetry" ) - func TestNewPocketIdManager(t *testing.T) { - type test struct { - name string - inputConfig PocketIdClientConfig - assertErrFunc require.ErrorAssertionFunc - assertErrFuncMessage string - } + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } - defaultTestConfig := PocketIdClientConfig{ - APIToken: "api_token", - ManagementEndpoint: "http://localhost", - } + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } - tests := []test{ - { - name: "Good Configuration", - inputConfig: defaultTestConfig, - assertErrFunc: require.NoError, - assertErrFuncMessage: "shouldn't return error", - }, - { - name: "Missing ManagementEndpoint", - inputConfig: PocketIdClientConfig{ - APIToken: defaultTestConfig.APIToken, - ManagementEndpoint: "", - }, - assertErrFunc: require.Error, - assertErrFuncMessage: "should return error when field empty", - }, - { - name: "Missing APIToken", - inputConfig: PocketIdClientConfig{ - APIToken: "", - ManagementEndpoint: defaultTestConfig.ManagementEndpoint, - }, - assertErrFunc: require.Error, - assertErrFuncMessage: "should return error when field empty", - }, - } + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) - tc.assertErrFunc(t, err, tc.assertErrFuncMessage) - }) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } } func TestPocketID_GetUserDataByID(t *testing.T) { - client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - md := AppMetadata{WTAccountID: "acc1"} - got, err := mgr.GetUserDataByID(context.Background(), "u1", md) - require.NoError(t, err) - assert.Equal(t, "u1", got.ID) - assert.Equal(t, "user1@example.com", got.Email) - assert.Equal(t, "User One", got.Name) - assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) } func TestPocketID_GetAccount_WithPagination(t *testing.T) { - // Single page response with two users - client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - users, err := mgr.GetAccount(context.Background(), "accX") - require.NoError(t, err) - require.Len(t, users, 2) - assert.Equal(t, "u1", users[0].ID) - assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) - assert.Equal(t, "u2", users[1].ID) + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) } func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { - client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - accounts, err := mgr.GetAllAccounts(context.Background()) - require.NoError(t, err) - require.Len(t, accounts[UnsetAccountID], 2) + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) } func TestPocketID_CreateUser(t *testing.T) { - client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") - require.NoError(t, err) - assert.Equal(t, "newid", ud.ID) - assert.Equal(t, "new@example.com", ud.Email) - assert.Equal(t, "New User", ud.Name) - assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) - if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { - assert.True(t, *ud.AppMetadata.WTPendingInvite) - } - assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) } func TestPocketID_InviteAndDeleteUser(t *testing.T) { - // Same mock for both calls; returns OK with empty JSON - client := &mockHTTPClient{code: 200, resBody: `{}`} + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - err = mgr.InviteUserByID(context.Background(), "u1") - require.NoError(t, err) + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) - err = mgr.DeleteUser(context.Background(), "u1") - require.NoError(t, err) + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index fc67e01af..496be9caa 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -364,7 +364,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := BuildManager(ctx, store, networkMapController, nil, "", + accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { diff --git a/management/server/management_test.go b/management/server/management_test.go index 930ecfb5a..c485f16b4 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -209,6 +209,7 @@ func startServer( accountManager, err := server.BuildManager( context.Background(), + nil, str, networkMapController, nil, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 781d84f5f..0178e51f5 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -2,6 +2,7 @@ package mock_server import ( "context" + "github.com/netbirdio/netbird/shared/auth" "net" "net/netip" "time" @@ -12,7 +13,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -34,7 +34,7 @@ type MockAccountManager struct { GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error @@ -84,7 +84,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func(settings *types.Settings) string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) @@ -119,7 +119,7 @@ type MockAccountManager struct { GetStoreFunc func() store.Store UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) @@ -470,7 +470,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { +func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { if am.GetUserFromUserAuthFunc != nil { return am.GetUserFromUserAuthFunc(ctx, userAuth) } @@ -675,7 +675,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if am.GetAccountIDFromUserAuthFunc != nil { return am.GetAccountIDFromUserAuthFunc(ctx, userAuth) } @@ -937,7 +937,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented") } -func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented") } @@ -969,7 +969,7 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") } -func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { if am.GetCurrentUserInfoFunc != nil { return am.GetCurrentUserInfoFunc(ctx, userAuth) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 35291b30c..51738c106 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -793,7 +793,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index c6cec6f7e..e2dea2c6b 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -10,8 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7874be858..6b8cf9412 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -8,11 +8,11 @@ import ( "github.com/rs/xid" - nbDomain "github.com/netbirdio/netbird/shared/management/domain" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 8054d05c6..6be90baa7 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -9,8 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index 72b15fd9a..e90c61a97 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -5,8 +5,8 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) type NetworkRouter struct { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 95c609595..21a6952a9 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1292,7 +1292,7 @@ func Test_RegisterPeerByUser(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1377,7 +1377,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1530,7 +1530,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1610,7 +1610,7 @@ func Test_LoginPeer(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index d65dc5045..f0bbbc32e 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,8 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/route_test.go b/management/server/route_test.go index 27fe033c8..7ff362bc6 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1292,7 +1292,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err } diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go index 6eb391cb5..da29e1d87 100644 --- a/management/server/types/route_firewall_rule.go +++ b/management/server/types/route_firewall_rule.go @@ -1,8 +1,8 @@ package types import ( - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // RouteFirewallRule a firewall rule applicable for a routed network. diff --git a/management/server/user.go b/management/server/user.go index be4e491a8..6b8bcbcad 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -7,12 +7,13 @@ import ( "strings" "time" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" - nbContext "github.com/netbirdio/netbird/management/server/context" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -175,9 +176,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id) } -// GetUser looks up a user by provided nbContext.UserAuths. +// GetUser looks up a user by provided auth.UserAuths. // Expects account to have been created already. -func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { +func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { return nil, err @@ -970,7 +971,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou var peerIDs []string for _, peer := range peers { // nolint:staticcheck - ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key) if peer.UserID == "" { // we do not want to expire peers that are added via setup key @@ -1214,7 +1215,7 @@ func validateUserInvite(invite *types.UserInfo) error { } // GetCurrentUserInfo retrieves the account's current user info and permissions -func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { accountID, userID := userAuth.AccountId, userAuth.UserId user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) diff --git a/management/server/user_test.go b/management/server/user_test.go index 69b8c85ee..5ce15621e 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -11,12 +11,12 @@ import ( "golang.org/x/exp/maps" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -966,7 +966,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { permissionsManager: permissionsManager, } - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: mockUserID, AccountId: mockAccountID, } @@ -1573,33 +1573,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { tt := []struct { name string - userAuth nbcontext.UserAuth + userAuth auth.UserAuth expectedErr error expectedResult *users.UserInfoWithPermissions }{ { name: "not found", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, { name: "not part of account", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, expectedErr: status.NewUserNotPartOfAccountError(), }, { name: "blocked", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, expectedErr: status.NewUserBlockedError(), }, { name: "service user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"}, expectedErr: status.NewPermissionDeniedError(), }, { name: "owner user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account1Owner", @@ -1619,7 +1619,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "regular user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "regular-user", @@ -1638,7 +1638,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "admin user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "admin-user", @@ -1657,7 +1657,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked regular user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1678,7 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { { name: "settings blocked regular user child account", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1698,7 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked owner user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account2Owner", diff --git a/relay/server/peer.go b/relay/server/peer.go index c47f2e960..c5ff41857 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -9,10 +9,10 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/relay/healthcheck" - "github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" + "github.com/netbirdio/netbird/shared/relay/healthcheck" + "github.com/netbirdio/netbird/shared/relay/messages" ) const ( diff --git a/management/server/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go similarity index 92% rename from management/server/auth/jwt/extractor.go rename to shared/auth/jwt/extractor.go index d270d0ff1..a41d5f07a 100644 --- a/management/server/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -8,7 +8,7 @@ import ( "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" ) const ( @@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string { return url } -func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) { +// ToUserAuth extracts user authentication information from a JWT token +func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) { claims := token.Claims.(jwt.MapClaims) - userAuth := nbcontext.UserAuth{} + userAuth := auth.UserAuth{} userID, ok := claims[c.userIDClaim].(string) if !ok { @@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro return userAuth, nil } +// ToGroups extracts group information from a JWT token func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string { claims := token.Claims.(jwt.MapClaims) userJWTGroups := make([]string, 0) diff --git a/management/server/auth/jwt/validator.go b/shared/auth/jwt/validator.go similarity index 100% rename from management/server/auth/jwt/validator.go rename to shared/auth/jwt/validator.go diff --git a/shared/auth/user.go b/shared/auth/user.go new file mode 100644 index 000000000..c1bae808e --- /dev/null +++ b/shared/auth/user.go @@ -0,0 +1,28 @@ +package auth + +import ( + "time" +) + +type UserAuth struct { + // The account id the user is accessing + AccountId string + // The account domain + Domain string + // The account domain category, TBC values + DomainCategory string + // Indicates whether this user was invited, TBC logic + Invited bool + // Indicates whether this is a child account + IsChild bool + + // The user id + UserId string + // Last login time for this user + LastLogin time.Time + // The Groups the user belongs to on this account + Groups []string + + // Indicates whether this user has authenticated with a Personal Access Token + IsPAT bool +} diff --git a/shared/context/keys.go b/shared/context/keys.go index 5345ee214..c5b5da044 100644 --- a/shared/context/keys.go +++ b/shared/context/keys.go @@ -5,4 +5,4 @@ const ( AccountIDKey = "accountID" UserIDKey = "userID" PeerIDKey = "peerID" -) \ No newline at end of file +) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d3f341529..f98e76ce7 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -117,7 +118,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } diff --git a/shared/management/operations/operation.go b/shared/management/operations/operation.go index b9b500362..b1ba12815 100644 --- a/shared/management/operations/operation.go +++ b/shared/management/operations/operation.go @@ -1,4 +1,4 @@ package operations // Operation represents a permission operation type -type Operation string \ No newline at end of file +type Operation string diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 0de00ec0c..ca12cf48c 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,19 +1,18 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.32.0 +// protoc v6.32.1 // source: management.proto package proto import ( - reflect "reflect" - sync "sync" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" ) const ( @@ -268,7 +267,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23, 0} + return file_management_proto_rawDescGZIP(), []int{24, 0} } type EncryptedMessage struct { @@ -799,16 +798,21 @@ type Flags struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` - DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"` - DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"` - DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"` - DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"` - BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"` - BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` - LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` + DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"` + DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"` + DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"` + DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"` + BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"` + BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` + LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + EnableSSHRoot bool `protobuf:"varint,11,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP bool `protobuf:"varint,12,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding bool `protobuf:"varint,13,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding bool `protobuf:"varint,14,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth bool `protobuf:"varint,15,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` } func (x *Flags) Reset() { @@ -913,6 +917,41 @@ func (x *Flags) GetLazyConnectionEnabled() bool { return false } +func (x *Flags) GetEnableSSHRoot() bool { + if x != nil { + return x.EnableSSHRoot + } + return false +} + +func (x *Flags) GetEnableSSHSFTP() bool { + if x != nil { + return x.EnableSSHSFTP + } + return false +} + +func (x *Flags) GetEnableSSHLocalPortForwarding() bool { + if x != nil { + return x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *Flags) GetEnableSSHRemotePortForwarding() bool { + if x != nil { + return x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *Flags) GetDisableSSHAuth() bool { + if x != nil { + return x.DisableSSHAuth + } + return false +} + // PeerSystemMeta is machine meta data like OS and version. type PeerSystemMeta struct { state protoimpl.MessageState @@ -1273,6 +1312,7 @@ type NetbirdConfig struct { Signal *HostConfig `protobuf:"bytes,3,opt,name=signal,proto3" json:"signal,omitempty"` Relay *RelayConfig `protobuf:"bytes,4,opt,name=relay,proto3" json:"relay,omitempty"` Flow *FlowConfig `protobuf:"bytes,5,opt,name=flow,proto3" json:"flow,omitempty"` + Jwt *JWTConfig `protobuf:"bytes,6,opt,name=jwt,proto3" json:"jwt,omitempty"` } func (x *NetbirdConfig) Reset() { @@ -1342,6 +1382,13 @@ func (x *NetbirdConfig) GetFlow() *FlowConfig { return nil } +func (x *NetbirdConfig) GetJwt() *JWTConfig { + if x != nil { + return x.Jwt + } + return nil +} + // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) type HostConfig struct { state protoimpl.MessageState @@ -1568,6 +1615,78 @@ func (x *FlowConfig) GetDnsCollection() bool { return false } +// JWTConfig represents JWT authentication configuration +type JWTConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Issuer string `protobuf:"bytes,1,opt,name=issuer,proto3" json:"issuer,omitempty"` + Audience string `protobuf:"bytes,2,opt,name=audience,proto3" json:"audience,omitempty"` + KeysLocation string `protobuf:"bytes,3,opt,name=keysLocation,proto3" json:"keysLocation,omitempty"` + MaxTokenAge int64 `protobuf:"varint,4,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"` +} + +func (x *JWTConfig) Reset() { + *x = JWTConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *JWTConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*JWTConfig) ProtoMessage() {} + +func (x *JWTConfig) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[17] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use JWTConfig.ProtoReflect.Descriptor instead. +func (*JWTConfig) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{17} +} + +func (x *JWTConfig) GetIssuer() string { + if x != nil { + return x.Issuer + } + return "" +} + +func (x *JWTConfig) GetAudience() string { + if x != nil { + return x.Audience + } + return "" +} + +func (x *JWTConfig) GetKeysLocation() string { + if x != nil { + return x.KeysLocation + } + return "" +} + +func (x *JWTConfig) GetMaxTokenAge() int64 { + if x != nil { + return x.MaxTokenAge + } + return 0 +} + // ProtectedHostConfig is similar to HostConfig but has additional user and password // Mostly used for TURN servers type ProtectedHostConfig struct { @@ -1583,7 +1702,7 @@ type ProtectedHostConfig struct { func (x *ProtectedHostConfig) Reset() { *x = ProtectedHostConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1596,7 +1715,7 @@ func (x *ProtectedHostConfig) String() string { func (*ProtectedHostConfig) ProtoMessage() {} func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1609,7 +1728,7 @@ func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProtectedHostConfig.ProtoReflect.Descriptor instead. func (*ProtectedHostConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{17} + return file_management_proto_rawDescGZIP(), []int{18} } func (x *ProtectedHostConfig) GetHostConfig() *HostConfig { @@ -1656,7 +1775,7 @@ type PeerConfig struct { func (x *PeerConfig) Reset() { *x = PeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1669,7 +1788,7 @@ func (x *PeerConfig) String() string { func (*PeerConfig) ProtoMessage() {} func (x *PeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1682,7 +1801,7 @@ func (x *PeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead. func (*PeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{18} + return file_management_proto_rawDescGZIP(), []int{19} } func (x *PeerConfig) GetAddress() string { @@ -1770,7 +1889,7 @@ type NetworkMap struct { func (x *NetworkMap) Reset() { *x = NetworkMap{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1783,7 +1902,7 @@ func (x *NetworkMap) String() string { func (*NetworkMap) ProtoMessage() {} func (x *NetworkMap) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1796,7 +1915,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead. func (*NetworkMap) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{19} + return file_management_proto_rawDescGZIP(), []int{20} } func (x *NetworkMap) GetSerial() uint64 { @@ -1904,7 +2023,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1917,7 +2036,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1930,7 +2049,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{20} + return file_management_proto_rawDescGZIP(), []int{21} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -1978,13 +2097,14 @@ type SSHConfig struct { SshEnabled bool `protobuf:"varint,1,opt,name=sshEnabled,proto3" json:"sshEnabled,omitempty"` // sshPubKey is a SSH public key of a peer to be added to authorized_hosts. // This property should be ignore if SSHConfig comes from PeerConfig. - SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"` + SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"` + JwtConfig *JWTConfig `protobuf:"bytes,3,opt,name=jwtConfig,proto3" json:"jwtConfig,omitempty"` } func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1997,7 +2117,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2010,7 +2130,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{21} + return file_management_proto_rawDescGZIP(), []int{22} } func (x *SSHConfig) GetSshEnabled() bool { @@ -2027,6 +2147,13 @@ func (x *SSHConfig) GetSshPubKey() []byte { return nil } +func (x *SSHConfig) GetJwtConfig() *JWTConfig { + if x != nil { + return x.JwtConfig + } + return nil +} + // DeviceAuthorizationFlowRequest empty struct for future expansion type DeviceAuthorizationFlowRequest struct { state protoimpl.MessageState @@ -2037,7 +2164,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2050,7 +2177,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[23] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2063,7 +2190,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{22} + return file_management_proto_rawDescGZIP(), []int{23} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -2082,7 +2209,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2095,7 +2222,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2108,7 +2235,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{24} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -2135,7 +2262,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2148,7 +2275,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2161,7 +2288,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{25} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -2178,7 +2305,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2191,7 +2318,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2204,7 +2331,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{26} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -2250,7 +2377,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2263,7 +2390,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2276,7 +2403,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *ProviderConfig) GetClientID() string { @@ -2384,7 +2511,7 @@ type Route struct { func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2397,7 +2524,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2410,7 +2537,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{28} } func (x *Route) GetID() string { @@ -2498,7 +2625,7 @@ type DNSConfig struct { func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2511,7 +2638,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2524,7 +2651,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *DNSConfig) GetServiceEnable() bool { @@ -2568,7 +2695,7 @@ type CustomZone struct { func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2581,7 +2708,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2594,7 +2721,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *CustomZone) GetDomain() string { @@ -2627,7 +2754,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2640,7 +2767,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2653,7 +2780,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *SimpleRecord) GetName() string { @@ -2706,7 +2833,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2719,7 +2846,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2732,7 +2859,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -2777,7 +2904,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2790,7 +2917,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2803,7 +2930,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{32} + return file_management_proto_rawDescGZIP(), []int{33} } func (x *NameServer) GetIP() string { @@ -2846,7 +2973,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2859,7 +2986,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2872,7 +2999,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33} + return file_management_proto_rawDescGZIP(), []int{34} } func (x *FirewallRule) GetPeerIP() string { @@ -2936,7 +3063,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2949,7 +3076,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2962,7 +3089,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{34} + return file_management_proto_rawDescGZIP(), []int{35} } func (x *NetworkAddress) GetNetIP() string { @@ -2990,7 +3117,7 @@ type Checks struct { func (x *Checks) Reset() { *x = Checks{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3003,7 +3130,7 @@ func (x *Checks) String() string { func (*Checks) ProtoMessage() {} func (x *Checks) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3016,7 +3143,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message { // Deprecated: Use Checks.ProtoReflect.Descriptor instead. func (*Checks) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{35} + return file_management_proto_rawDescGZIP(), []int{36} } func (x *Checks) GetFiles() []string { @@ -3041,7 +3168,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3054,7 +3181,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3067,7 +3194,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36} + return file_management_proto_rawDescGZIP(), []int{37} } func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -3138,7 +3265,7 @@ type RouteFirewallRule struct { func (x *RouteFirewallRule) Reset() { *x = RouteFirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3151,7 +3278,7 @@ func (x *RouteFirewallRule) String() string { func (*RouteFirewallRule) ProtoMessage() {} func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3164,7 +3291,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. func (*RouteFirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37} + return file_management_proto_rawDescGZIP(), []int{38} } func (x *RouteFirewallRule) GetSourceRanges() []string { @@ -3255,7 +3382,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3268,7 +3395,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3281,7 +3408,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38} + return file_management_proto_rawDescGZIP(), []int{39} } func (x *ForwardingRule) GetProtocol() RuleProtocol { @@ -3324,7 +3451,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3337,7 +3464,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3350,7 +3477,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36, 0} + return file_management_proto_rawDescGZIP(), []int{37, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -3438,7 +3565,7 @@ var file_management_proto_rawDesc = []byte{ 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, 0x12, 0x2a, 0x0a, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, - 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xc1, 0x03, 0x0a, 0x05, + 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xbf, 0x05, 0x0a, 0x05, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, @@ -3466,435 +3593,465 @@ var file_management_proto_rawDesc = []byte{ 0x63, 0x6b, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x6c, 0x61, 0x7a, 0x79, 0x43, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, - 0xf2, 0x04, 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, - 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, - 0x4f, 0x53, 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, - 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, - 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, - 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, - 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, - 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, - 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, - 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, - 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, - 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, - 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, - 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, - 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, - 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, - 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, - 0x65, 0x6e, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, - 0x6c, 0x61, 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, - 0x6c, 0x61, 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, - 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, - 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, - 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x2a, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, - 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, - 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, - 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, - 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, - 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, - 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, - 0x77, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, - 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, - 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, - 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, - 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, - 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, - 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, - 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, - 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, - 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, - 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, - 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, - 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, - 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, - 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, - 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x7d, 0x0a, 0x13, 0x50, - 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, - 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, - 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, - 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, - 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, - 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, - 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, - 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, - 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, - 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, - 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, - 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, - 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, - 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x6f, 0x6f, 0x74, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, + 0x48, 0x52, 0x6f, 0x6f, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, + 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x12, 0x42, 0x0a, 0x1c, 0x65, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x6f, 0x72, + 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x18, 0x0d, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x1c, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61, + 0x6c, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, + 0x44, 0x0a, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, + 0x18, 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, + 0x48, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x26, 0x0a, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x22, 0xf2, 0x04, + 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, + 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, + 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, + 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62, + 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, + 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, + 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, + 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, + 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, + 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, + 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, + 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, + 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, + 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, + 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, + 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, + 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, + 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, + 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xa8, 0x02, + 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, + 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, + 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, + 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x12, + 0x27, 0x0a, 0x03, 0x6a, 0x77, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x03, 0x6a, 0x77, 0x74, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, + 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, + 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, + 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, + 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, + 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, + 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, + 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, + 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, + 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, + 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, + 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, + 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, + 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, + 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, + 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, + 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, + 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, + 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, + 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, + 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, + 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, + 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, + 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, + 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, + 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, + 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, + 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, + 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, + 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, + 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, + 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, + 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, + 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, + 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, - 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, - 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, - 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, - 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, - 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, - 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, - 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, - 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, - 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, - 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, - 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, - 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, - 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, - 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, - 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, - 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, - 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x24, - 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, - 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, - 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, - 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, - 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, - 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, - 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, - 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, - 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, - 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, - 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, - 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, - 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, - 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, - 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, - 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, - 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, - 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, - 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, - 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, - 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, - 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, - 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, + 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, + 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, + 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, + 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, + 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, + 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, + 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, + 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, + 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, + 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, + 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, + 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, + 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, + 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, + 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, + 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, + 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, + 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, + 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, + 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, + 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, + 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, + 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, + 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, + 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, + 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, + 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, + 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, + 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, + 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, + 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, + 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, + 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, + 0x65, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, + 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, + 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, + 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, + 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, + 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, + 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, + 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, + 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, + 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, + 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, + 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, + 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, + 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, + 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, + 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, + 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, + 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, - 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, + 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3910,7 +4067,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 40) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 41) var file_management_proto_goTypes = []interface{}{ (RuleProtocol)(0), // 0: management.RuleProtocol (RuleDirection)(0), // 1: management.RuleDirection @@ -3934,107 +4091,110 @@ var file_management_proto_goTypes = []interface{}{ (*HostConfig)(nil), // 19: management.HostConfig (*RelayConfig)(nil), // 20: management.RelayConfig (*FlowConfig)(nil), // 21: management.FlowConfig - (*ProtectedHostConfig)(nil), // 22: management.ProtectedHostConfig - (*PeerConfig)(nil), // 23: management.PeerConfig - (*NetworkMap)(nil), // 24: management.NetworkMap - (*RemotePeerConfig)(nil), // 25: management.RemotePeerConfig - (*SSHConfig)(nil), // 26: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 27: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 28: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 29: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 30: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 31: management.ProviderConfig - (*Route)(nil), // 32: management.Route - (*DNSConfig)(nil), // 33: management.DNSConfig - (*CustomZone)(nil), // 34: management.CustomZone - (*SimpleRecord)(nil), // 35: management.SimpleRecord - (*NameServerGroup)(nil), // 36: management.NameServerGroup - (*NameServer)(nil), // 37: management.NameServer - (*FirewallRule)(nil), // 38: management.FirewallRule - (*NetworkAddress)(nil), // 39: management.NetworkAddress - (*Checks)(nil), // 40: management.Checks - (*PortInfo)(nil), // 41: management.PortInfo - (*RouteFirewallRule)(nil), // 42: management.RouteFirewallRule - (*ForwardingRule)(nil), // 43: management.ForwardingRule - (*PortInfo_Range)(nil), // 44: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 45: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 46: google.protobuf.Duration + (*JWTConfig)(nil), // 22: management.JWTConfig + (*ProtectedHostConfig)(nil), // 23: management.ProtectedHostConfig + (*PeerConfig)(nil), // 24: management.PeerConfig + (*NetworkMap)(nil), // 25: management.NetworkMap + (*RemotePeerConfig)(nil), // 26: management.RemotePeerConfig + (*SSHConfig)(nil), // 27: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 28: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 29: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 30: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 31: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 32: management.ProviderConfig + (*Route)(nil), // 33: management.Route + (*DNSConfig)(nil), // 34: management.DNSConfig + (*CustomZone)(nil), // 35: management.CustomZone + (*SimpleRecord)(nil), // 36: management.SimpleRecord + (*NameServerGroup)(nil), // 37: management.NameServerGroup + (*NameServer)(nil), // 38: management.NameServer + (*FirewallRule)(nil), // 39: management.FirewallRule + (*NetworkAddress)(nil), // 40: management.NetworkAddress + (*Checks)(nil), // 41: management.Checks + (*PortInfo)(nil), // 42: management.PortInfo + (*RouteFirewallRule)(nil), // 43: management.RouteFirewallRule + (*ForwardingRule)(nil), // 44: management.ForwardingRule + (*PortInfo_Range)(nil), // 45: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 46: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 47: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ 14, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta 18, // 1: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig - 23, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 25, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 24, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 40, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 24, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 26, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 25, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 41, // 5: management.SyncResponse.Checks:type_name -> management.Checks 14, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 14, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 39, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 40, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment 12, // 11: management.PeerSystemMeta.files:type_name -> management.File 13, // 12: management.PeerSystemMeta.flags:type_name -> management.Flags 18, // 13: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig - 23, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 40, // 15: management.LoginResponse.Checks:type_name -> management.Checks - 45, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 24, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 41, // 15: management.LoginResponse.Checks:type_name -> management.Checks + 46, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 19, // 17: management.NetbirdConfig.stuns:type_name -> management.HostConfig - 22, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig + 23, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 46, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 26, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 23, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 25, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 32, // 28: management.NetworkMap.Routes:type_name -> management.Route - 33, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 25, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 38, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 42, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 43, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 26, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 4, // 35: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 31, // 36: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 31, // 37: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 36, // 38: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 34, // 39: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 35, // 40: management.CustomZone.Records:type_name -> management.SimpleRecord - 37, // 41: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 42: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 43: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 44: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 41, // 45: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 44, // 46: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 47: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 48: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 41, // 49: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 50: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 41, // 51: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 41, // 52: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 53: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 54: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 55: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 56: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 57: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 58: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 59: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 63: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 64: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 65: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 66: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 67: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 68: management.ManagementService.Logout:output_type -> management.Empty - 61, // [61:69] is the sub-list for method output_type - 53, // [53:61] is the sub-list for method input_type - 53, // [53:53] is the sub-list for extension type_name - 53, // [53:53] is the sub-list for extension extendee - 0, // [0:53] is the sub-list for field type_name + 22, // 22: management.NetbirdConfig.jwt:type_name -> management.JWTConfig + 3, // 23: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 47, // 24: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 19, // 25: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 27, // 26: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 26, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 33, // 29: management.NetworkMap.Routes:type_name -> management.Route + 34, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 26, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 39, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 43, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 44, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 27, // 35: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 36: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 37: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 32, // 38: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 32, // 39: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 37, // 40: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 35, // 41: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 36, // 42: management.CustomZone.Records:type_name -> management.SimpleRecord + 38, // 43: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 44: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 45: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 46: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 42, // 47: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 45, // 48: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 49: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 50: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 42, // 51: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 52: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 42, // 53: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 42, // 54: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 55: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 56: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 57: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 58: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 59: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 60: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 61: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 62: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 64: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 65: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 66: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 67: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 68: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 69: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 70: management.ManagementService.Logout:output_type -> management.Empty + 63, // [63:71] is the sub-list for method output_type + 55, // [55:63] is the sub-list for method input_type + 55, // [55:55] is the sub-list for extension type_name + 55, // [55:55] is the sub-list for extension extendee + 0, // [0:55] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -4248,7 +4408,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProtectedHostConfig); i { + switch v := v.(*JWTConfig); i { case 0: return &v.state case 1: @@ -4260,7 +4420,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerConfig); i { + switch v := v.(*ProtectedHostConfig); i { case 0: return &v.state case 1: @@ -4272,7 +4432,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*PeerConfig); i { case 0: return &v.state case 1: @@ -4284,7 +4444,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -4296,7 +4456,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -4308,7 +4468,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -4320,7 +4480,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4332,7 +4492,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4344,7 +4504,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4356,7 +4516,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4368,7 +4528,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -4380,7 +4540,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -4392,7 +4552,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -4404,7 +4564,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -4416,7 +4576,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -4428,7 +4588,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -4440,7 +4600,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -4452,7 +4612,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -4464,7 +4624,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Checks); i { + switch v := v.(*NetworkAddress); i { case 0: return &v.state case 1: @@ -4476,7 +4636,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PortInfo); i { + switch v := v.(*Checks); i { case 0: return &v.state case 1: @@ -4488,7 +4648,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouteFirewallRule); i { + switch v := v.(*PortInfo); i { case 0: return &v.state case 1: @@ -4500,7 +4660,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ForwardingRule); i { + switch v := v.(*RouteFirewallRule); i { case 0: return &v.state case 1: @@ -4512,6 +4672,18 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForwardingRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -4524,7 +4696,7 @@ func file_management_proto_init() { } } } - file_management_proto_msgTypes[36].OneofWrappers = []interface{}{ + file_management_proto_msgTypes[37].OneofWrappers = []interface{}{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } @@ -4534,7 +4706,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 40, + NumMessages: 41, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 3982ea2af..16737cf58 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -146,6 +146,12 @@ message Flags { bool blockInbound = 9; bool lazyConnectionEnabled = 10; + + bool enableSSHRoot = 11; + bool enableSSHSFTP = 12; + bool enableSSHLocalPortForwarding = 13; + bool enableSSHRemotePortForwarding = 14; + bool disableSSHAuth = 15; } // PeerSystemMeta is machine meta data like OS and version. @@ -202,6 +208,8 @@ message NetbirdConfig { RelayConfig relay = 4; FlowConfig flow = 5; + + JWTConfig jwt = 6; } // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) @@ -240,6 +248,14 @@ message FlowConfig { bool dnsCollection = 8; } +// JWTConfig represents JWT authentication configuration +message JWTConfig { + string issuer = 1; + string audience = 2; + string keysLocation = 3; + int64 maxTokenAge = 4; +} + // ProtectedHostConfig is similar to HostConfig but has additional user and password // Mostly used for TURN servers message ProtectedHostConfig { @@ -335,6 +351,8 @@ message SSHConfig { // sshPubKey is a SSH public key of a peer to be added to authorized_hosts. // This property should be ignore if SSHConfig comes from PeerConfig. bytes sshPubKey = 2; + + JWTConfig jwtConfig = 3; } // DeviceAuthorizationFlowRequest empty struct for future expansion diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 967e18d79..c057ef089 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -11,8 +11,8 @@ import ( "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" - quictls "github.com/netbirdio/netbird/shared/relay/tls" nbnet "github.com/netbirdio/netbird/client/net" + quictls "github.com/netbirdio/netbird/shared/relay/tls" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 66fff3447..37b189e05 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -14,9 +14,9 @@ import ( "github.com/coder/websocket" log "github.com/sirupsen/logrus" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/constants.go b/shared/relay/constants.go index 3c7c3cd29..0f2a27610 100644 --- a/shared/relay/constants.go +++ b/shared/relay/constants.go @@ -3,4 +3,4 @@ package relay const ( // WebSocketURLPath is the path for the websocket relay connection WebSocketURLPath = "/relay" -) \ No newline at end of file +) diff --git a/version/url_windows.go b/version/url_windows.go index 14fdb7ae6..a0fb6e5dd 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -6,7 +6,7 @@ import ( ) const ( - urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExe = "https://pkgs.netbird.io/windows/x64" urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) @@ -18,11 +18,11 @@ func DownloadUrl() string { if err != nil { return downloadURL } - + url := urlWinExe if runtime.GOARCH == "arm64" { url = urlWinExeArm } - + return url } From 4eeb2d8debcb7c6e2c88f9c8fbc94958f8a9f836 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Mon, 17 Nov 2025 18:20:30 +0100 Subject: [PATCH 060/118] [management] added exception on not appending route firewall rules if we have all wildcard (#4801) --- management/server/types/networkmapbuilder.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 41aaa7fc8..6361e2e93 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -22,10 +22,11 @@ import ( ) const ( - allPeers = "0.0.0.0" - fw = "fw:" - rfw = "route-fw:" - nr = "network-resource-" + allPeers = "0.0.0.0" + allWildcard = "0.0.0.0/0" + v6AllWildcard = "::/0" + fw = "fw:" + rfw = "route-fw:" ) type NetworkMapCache struct { @@ -1640,6 +1641,10 @@ func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, } if string(rule.RouteID) == update.RuleID { + if hasWildcard := slices.Contains(rule.SourceRanges, allWildcard) || slices.Contains(rule.SourceRanges, v6AllWildcard); hasWildcard { + break + } + sourceIP := update.AddSourceIP if strings.Contains(sourceIP, ":") { From 60f4d5f9b0ab6058ccfc7805c6143d1292b4bcec Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 18 Nov 2025 12:41:17 +0100 Subject: [PATCH 061/118] [client] Revert migrate deprecated grpc client code #4805 --- client/grpc/dialer.go | 42 +++++++----------------------------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 7763f2417..54966b50e 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "runtime" "time" @@ -12,7 +11,6 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -20,9 +18,6 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) -// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready -var ErrConnectionShutdown = errors.New("connection shutdown before ready") - // Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -31,26 +26,6 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -// waitForConnectionReady blocks until the connection becomes ready or fails. -// Returns an error if the connection times out, is cancelled, or enters shutdown state. -func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { - conn.Connect() - - state := conn.GetState() - for state != connectivity.Ready && state != connectivity.Shutdown { - if !conn.WaitForStateChange(ctx, state) { - return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) - } - state = conn.GetState() - } - - if state == connectivity.Shutdown { - return ErrConnectionShutdown - } - - return nil -} - // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -68,25 +43,22 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - conn, err := grpc.NewClient( + connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + connCtx, addr, transportOption, WithCustomDialer(tlsEnabled, component), + grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - return nil, fmt.Errorf("new client: %w", err) - } - - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - if err := waitForConnectionReady(ctx, conn); err != nil { - _ = conn.Close() - return nil, err + return nil, fmt.Errorf("dial context: %w", err) } return conn, nil From 05cbead39b21c5b2c69a2df020ede0746058faad Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 18 Nov 2025 17:15:57 +0100 Subject: [PATCH 062/118] [management] Fix direct peer networks route (#4802) --- management/server/types/account.go | 11 ++ management/server/types/networkmapbuilder.go | 152 +++++++++++++++---- 2 files changed, 130 insertions(+), 33 deletions(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index 8797e1fa3..51313819f 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1242,6 +1242,13 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID } } } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := validatedPeersMap[rule.SourceResource.ID] + if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) for pID := range distPeersWithPolicy { @@ -1587,6 +1594,10 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st sourcePeers[peer] = struct{}{} } } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers[rule.SourceResource.ID] = struct{}{} + } } } diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 6361e2e93..5790f1646 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -858,6 +858,14 @@ func (b *NetworkMapBuilder) getRulePeers( } } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := validatedPeersMap[rule.SourceResource.ID] + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) for pID := range distPeersWithPolicy { peer := b.cache.globalPeers[pID] @@ -1287,24 +1295,54 @@ func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPee if !rule.Enabled { continue } + var peerInSources, peerInDestinations bool - peerInSources := b.isPeerInGroups(rule.Sources, peerGroups) - peerInDestinations := b.isPeerInGroups(rule.Destinations, peerGroups) + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == newPeerID { + peerInSources = true + } else { + peerInSources = b.isPeerInGroups(rule.Sources, peerGroups) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == newPeerID { + peerInDestinations = true + } else { + peerInDestinations = b.isPeerInGroups(rule.Destinations, peerGroups) + } if peerInSources { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + if len(rule.Destinations) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionIN) + } } if peerInDestinations { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + if len(rule.Sources) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) + } } if rule.Bidirectional { if peerInSources { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + if len(rule.Destinations) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) + } } if peerInDestinations { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + if len(rule.Sources) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionIN) + } } } } @@ -1566,42 +1604,90 @@ func (b *NetworkMapBuilder) addUpdateForPeersInGroups( if _, ok := b.validatedPeers[peerID]; !ok { continue } - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - AddConnectedPeer: newPeerID, - AddFirewallRules: make([]*FirewallRuleDelta, 0), - } - updates[peerID] = delta + targetPeer := b.cache.globalPeers[peerID] + if targetPeer == nil { + continue } + peerIPForRule := fr.PeerIP if all { - fr.PeerIP = allPeers + peerIPForRule = allPeers } - if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { - expandedRules := expandPortsAndRanges(*fr, rule, b.cache.globalPeers[peerID]) - for _, expandedRule := range expandedRules { - ruleID := b.generateFirewallRuleID(expandedRule) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: expandedRule, - RuleID: ruleID, - Direction: direction, - }) - } - } else { - ruleID := b.generateFirewallRuleID(fr) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: fr, - RuleID: ruleID, - Direction: direction, - }) - } + b.addOrUpdateFirewallRuleInDelta(updates, peerID, newPeerID, rule, direction, fr, peerIPForRule, targetPeer) } } } +func (b *NetworkMapBuilder) addUpdateForDirectPeerResource( + updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, + rule *PolicyRule, direction int, +) { + if targetPeerID == newPeerID { + return + } + + if _, ok := b.validatedPeers[targetPeerID]; !ok { + return + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return + } + + targetPeer := b.cache.globalPeers[targetPeerID] + if targetPeer == nil { + return + } + + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer) +} + +func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( + updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, + rule *PolicyRule, direction int, baseRule *FirewallRule, peerIP string, targetPeer *nbpeer.Peer, +) { + delta := updates[targetPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: targetPeerID, + AddConnectedPeer: newPeerID, + AddFirewallRules: make([]*FirewallRuleDelta, 0), + } + updates[targetPeerID] = delta + } + + baseRule.PeerIP = peerIP + + if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { + expandedRules := expandPortsAndRanges(*baseRule, rule, targetPeer) + for _, expandedRule := range expandedRules { + ruleID := b.generateFirewallRuleID(expandedRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: expandedRule, + RuleID: ruleID, + Direction: direction, + }) + } + } else { + ruleID := b.generateFirewallRuleID(baseRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: baseRule, + RuleID: ruleID, + Direction: direction, + }) + } +} + func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { if aclView := b.cache.peerACLs[peerID]; aclView != nil { From 3351b38434c51151a03b6e4f0dc023b8bbb28f59 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 19 Nov 2025 11:52:18 +0100 Subject: [PATCH 063/118] [management] pass config to controller (#4807) --- client/cmd/testutil_test.go | 2 +- client/internal/engine.go | 1 + client/internal/engine_test.go | 2 +- client/server/server_test.go | 2 +- .../network_map/controller/controller.go | 9 +- management/internals/server/controllers.go | 2 +- .../internals/shared/grpc/conversion.go | 30 +- management/internals/shared/grpc/server.go | 4 +- management/server/account_test.go | 3 +- management/server/dns_test.go | 3 +- .../testing/testing_tools/channel/channel.go | 3 +- management/server/management_proto_test.go | 2 +- management/server/management_test.go | 2 +- management/server/nameserver_test.go | 3 +- management/server/peer_test.go | 12 +- management/server/route_test.go | 3 +- shared/management/client/client_test.go | 2 +- shared/management/proto/management.pb.go | 849 +++++++++--------- shared/management/proto/management.proto | 2 - 19 files changed, 464 insertions(+), 472 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index e7b0279e8..5477653a2 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -116,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 1deb3d3cf..94c948398 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1192,6 +1192,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + //nolint forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) if forwarderPort == 0 { forwarderPort = nbdns.ForwarderClientPort diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 3b7ff0eba..d6ab7391e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1624,7 +1624,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) - networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err diff --git a/client/server/server_test.go b/client/server/server_test.go index 96d4c0af0..f8592bc7a 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -316,7 +316,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) - networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index ad25494c7..49bb9cef3 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -19,6 +19,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" @@ -47,6 +48,7 @@ type Controller struct { updateAccountPeersBufferInterval atomic.Int64 // dnsDomain is used for peer resolution. This is appended to the peer's name dnsDomain string + config *config.Config requestBuffer account.RequestBuffer @@ -68,7 +70,7 @@ type bufferUpdate struct { var _ network_map.Controller = (*Controller)(nil) -func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller { +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller { nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) if err != nil { log.Fatal(fmt.Errorf("error creating metrics: %w", err)) @@ -95,6 +97,7 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App integratedPeerValidator: integratedPeerValidator, settingsManager: settingsManager, dnsDomain: dnsDomain, + config: config, proxyController: proxyController, @@ -205,7 +208,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin peerGroups := account.GetPeerGroups(p.ID) start = time.Now() - update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) + update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) c.metrics.CountToSyncResponseDuration(time.Since(start)) c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) @@ -323,7 +326,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe peerGroups := account.GetPeerGroups(peerId) dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) - update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) + update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) return nil diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index b61e33688..38ec6fde6 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -70,7 +70,7 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager { func (s *BaseServer) NetworkMapController() network_map.Controller { return Create(s, func() *nmapcontroller.Controller { - return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController()) + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config) }) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 7f64034df..148db4e37 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -83,7 +83,7 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken return nbConfig } -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, config *nbconfig.Config) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) @@ -92,7 +92,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } if peer.SSHEnabled { - sshConfig.JwtConfig = buildJWTConfig(config) + sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) } return &proto.PeerConfig{ @@ -104,9 +104,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } } -func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { +func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, config), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -363,35 +363,29 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } // buildJWTConfig constructs JWT configuration for SSH servers from management server config -func buildJWTConfig(config *nbconfig.Config) *proto.JWTConfig { - if config == nil { +func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig { + if config == nil || config.AuthAudience == "" { return nil } - if config.HttpConfig == nil || config.HttpConfig.AuthAudience == "" { - return nil - } - - issuer := strings.TrimSpace(config.HttpConfig.AuthIssuer) - if issuer == "" { - if config.DeviceAuthorizationFlow != nil { - if d := deriveIssuerFromTokenEndpoint(config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint); d != "" { - issuer = d - } + issuer := strings.TrimSpace(config.AuthIssuer) + if issuer == "" || deviceFlowConfig != nil { + if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" { + issuer = d } } if issuer == "" { return nil } - keysLocation := strings.TrimSpace(config.HttpConfig.AuthKeysLocation) + keysLocation := strings.TrimSpace(config.AuthKeysLocation) if keysLocation == "" { keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json" } return &proto.JWTConfig{ Issuer: issuer, - Audience: config.HttpConfig.AuthAudience, + Audience: config.AuthAudience, KeysLocation: keysLocation, } } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 4364272a0..0aadadf84 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -646,7 +646,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow), Checks: toProtocolChecks(ctx, postureChecks), } @@ -713,7 +713,7 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/account_test.go b/management/server/account_test.go index 10d718bbf..340e8db18 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -2958,7 +2959,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 6b7a36c20..99b09566a 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -12,6 +12,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -222,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index ac165aeb2..e292a7d6c 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -71,7 +72,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee ctx := context.Background() requestBuffer := server.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), &config.Config{}) am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 496be9caa..42311d944 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -363,7 +363,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) diff --git a/management/server/management_test.go b/management/server/management_test.go index c485f16b4..2350b225b 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -205,7 +205,7 @@ func startServer( ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) - networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := server.BuildManager( context.Background(), diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 51738c106..a4574c978 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -13,6 +13,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -791,7 +792,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 21a6952a9..2d09f5200 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1058,7 +1058,6 @@ func testUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel - assert.Nil(t, update.Update.NetbirdConfig) assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } @@ -1177,7 +1176,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &cache.DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) + response := grpc.ToSyncResponse(context.Background(), config, config.HttpConfig, config.DeviceAuthorizationFlow, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1228,6 +1227,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID) // assert network map DNSConfig assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable) + //nolint assert.Equal(t, int64(dnsForwarderPort), response.NetworkMap.DNSConfig.ForwarderPort) assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones)) assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups)) @@ -1290,7 +1290,7 @@ func Test_RegisterPeerByUser(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1375,7 +1375,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1528,7 +1528,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1608,7 +1608,7 @@ func Test_LoginPeer(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) diff --git a/management/server/route_test.go b/management/server/route_test.go index 7ff362bc6..5c8b636bc 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -1290,7 +1291,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index f98e76ce7..9e08317f6 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -117,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index ca12cf48c..45211b7f4 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.32.1 +// protoc v6.33.0 // source: management.proto package proto @@ -1312,7 +1312,6 @@ type NetbirdConfig struct { Signal *HostConfig `protobuf:"bytes,3,opt,name=signal,proto3" json:"signal,omitempty"` Relay *RelayConfig `protobuf:"bytes,4,opt,name=relay,proto3" json:"relay,omitempty"` Flow *FlowConfig `protobuf:"bytes,5,opt,name=flow,proto3" json:"flow,omitempty"` - Jwt *JWTConfig `protobuf:"bytes,6,opt,name=jwt,proto3" json:"jwt,omitempty"` } func (x *NetbirdConfig) Reset() { @@ -1382,13 +1381,6 @@ func (x *NetbirdConfig) GetFlow() *FlowConfig { return nil } -func (x *NetbirdConfig) GetJwt() *JWTConfig { - if x != nil { - return x.Jwt - } - return nil -} - // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) type HostConfig struct { state protoimpl.MessageState @@ -2619,7 +2611,8 @@ type DNSConfig struct { ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"` NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"` CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"` - ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` + // Deprecated: Do not use. + ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` } func (x *DNSConfig) Reset() { @@ -2675,6 +2668,7 @@ func (x *DNSConfig) GetCustomZones() []*CustomZone { return nil } +// Deprecated: Do not use. func (x *DNSConfig) GetForwarderPort() int64 { if x != nil { return x.ForwarderPort @@ -3668,7 +3662,7 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xa8, 0x02, + 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, @@ -3684,374 +3678,372 @@ var file_management_proto_rawDesc = []byte{ 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, - 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x12, - 0x27, 0x0a, 0x03, 0x6a, 0x77, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x03, 0x6a, 0x77, 0x74, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, - 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, - 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, - 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, - 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, - 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, - 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, - 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, - 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, - 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, - 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, - 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, - 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, - 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, - 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, - 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, - 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, - 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, - 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, - 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, - 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, - 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, - 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, + 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, + 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, + 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, + 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, + 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, + 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, + 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, + 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, + 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, + 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, + 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, + 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, + 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, + 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, + 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, + 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, + 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, + 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57, + 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, + 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, + 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, + 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, + 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, + 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, - 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, - 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, - 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, - 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, - 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, - 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, - 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, - 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, - 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, - 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, - 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, - 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, - 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, - 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, - 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, - 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, - 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, - 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, - 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, - 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, - 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, - 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, - 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, - 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, - 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, - 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, - 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, - 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, - 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, - 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, - 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, - 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, - 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, - 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, - 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, - 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, - 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, - 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, - 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, - 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, - 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, - 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, - 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, - 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, - 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, - 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, - 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, - 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, - 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, - 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, - 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, - 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, - 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, - 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, - 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, - 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, - 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, - 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, - 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, - 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, - 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, - 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, - 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, - 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, - 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, - 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, - 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, - 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, - 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, - 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, - 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, - 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, - 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, - 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, - 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, - 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, + 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, + 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, + 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, + 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, + 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, + 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, + 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, + 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, + 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, + 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, + 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, + 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, + 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, + 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, + 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, + 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, + 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, + 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, + 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, + 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, + 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, + 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, + 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, + 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, + 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, + 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, + 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, + 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, + 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, + 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, + 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, + 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, + 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, + 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, + 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, + 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, + 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, + 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, + 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, + 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, + 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, + 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, + 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, + 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, + 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, + 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, + 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, + 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, + 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, - 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, - 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, - 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, - 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, - 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, - 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, - 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, - 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, - 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, - 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, + 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, + 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, + 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, + 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, + 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, + 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, + 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, + 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, + 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, + 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, + 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, + 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, + 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, + 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, + 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, + 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, + 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, + 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, + 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, + 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, + 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, + 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, + 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, + 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, - 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, - 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, + 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, + 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, + 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, + 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, + 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -4141,60 +4133,59 @@ var file_management_proto_depIdxs = []int32{ 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 22, // 22: management.NetbirdConfig.jwt:type_name -> management.JWTConfig - 3, // 23: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 47, // 24: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 19, // 25: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 27, // 26: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 26, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 33, // 29: management.NetworkMap.Routes:type_name -> management.Route - 34, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 26, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 39, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 43, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 44, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 27, // 35: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 22, // 36: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 4, // 37: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 32, // 38: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 32, // 39: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 37, // 40: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 35, // 41: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 36, // 42: management.CustomZone.Records:type_name -> management.SimpleRecord - 38, // 43: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 44: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 45: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 46: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 42, // 47: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 45, // 48: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 49: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 50: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 42, // 51: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 52: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 42, // 53: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 42, // 54: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 55: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 56: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 57: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 58: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 59: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 63: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 64: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 65: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 66: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 67: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 68: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 69: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 70: management.ManagementService.Logout:output_type -> management.Empty - 63, // [63:71] is the sub-list for method output_type - 55, // [55:63] is the sub-list for method input_type - 55, // [55:55] is the sub-list for extension type_name - 55, // [55:55] is the sub-list for extension extendee - 0, // [0:55] is the sub-list for field type_name + 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 47, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 27, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 24, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 26, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 33, // 28: management.NetworkMap.Routes:type_name -> management.Route + 34, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 26, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 39, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 43, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 44, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 27, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 35: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 36: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 32, // 37: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 32, // 38: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 37, // 39: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 35, // 40: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 36, // 41: management.CustomZone.Records:type_name -> management.SimpleRecord + 38, // 42: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 43: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 44: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 45: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 42, // 46: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 45, // 47: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 48: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 49: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 42, // 50: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 51: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 42, // 52: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 42, // 53: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 54: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 55: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 56: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 57: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 58: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 59: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 60: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 61: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 62: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 64: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 65: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 66: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 67: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 68: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 69: management.ManagementService.Logout:output_type -> management.Empty + 62, // [62:70] is the sub-list for method output_type + 54, // [54:62] is the sub-list for method input_type + 54, // [54:54] is the sub-list for extension type_name + 54, // [54:54] is the sub-list for extension extendee + 0, // [0:54] is the sub-list for field type_name } func init() { file_management_proto_init() } diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 16737cf58..8a297e75e 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -208,8 +208,6 @@ message NetbirdConfig { RelayConfig relay = 4; FlowConfig flow = 5; - - JWTConfig jwt = 6; } // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) From 68f56b797d43a2bc9c3a68eed9c33866c3fcf0d6 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 19 Nov 2025 13:16:47 +0100 Subject: [PATCH 064/118] [management] Add native ssh port rule on 22 (#4810) Implements feature-aware firewall rule expansion: derives peer-supported features (native SSH, portRanges) from peer version, prefers explicit Ports over PortRanges when expanding, conditionally appends a native SSH (22022) rule when policy and peer support allow, and adds helpers plus tests for SSH expansion behavior. --- management/server/types/account.go | 90 ++++++-- management/server/types/account_test.go | 266 ++++++++++++++++++++++++ 2 files changed, 341 insertions(+), 15 deletions(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index 51313819f..9e86d8936 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -40,8 +40,20 @@ const ( // firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules. firewallRuleMinPortRangesVer = "0.48.0" + // firewallRuleMinNativeSSHVer defines the minimum peer version that supports native SSH features in the firewall rules. + firewallRuleMinNativeSSHVer = "0.60.0" + + // nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections. + nativeSSHPortString = "22022" + // defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections. + defaultSSHPortString = "22" ) +type supportedFeatures struct { + nativeSSH bool + portRanges bool +} + type LookupMap map[string]struct{} // AccountMeta is a struct that contains a stripped down version of the Account object. @@ -1650,22 +1662,24 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error { // expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule { + features := peerSupportedFirewallFeatures(peer.Meta.WtVersion) + var expanded []*FirewallRule - if len(rule.Ports) > 0 { - for _, port := range rule.Ports { - fr := base - fr.Port = port - expanded = append(expanded, &fr) - } - return expanded + for _, port := range rule.Ports { + fr := base + fr.Port = port + expanded = append(expanded, &fr) } - supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion) for _, portRange := range rule.PortRanges { + // prefer PolicyRule.Ports + if len(rule.Ports) > 0 { + break + } fr := base - if supportPortRanges { + if features.portRanges { fr.PortRange = portRange } else { // Peer doesn't support port ranges, only allow single-port ranges @@ -1677,17 +1691,63 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer expanded = append(expanded, &fr) } + if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) { + expanded = addNativeSSHRule(base, expanded) + } + return expanded } -// peerSupportsPortRanges checks if the peer version supports port ranges. -func peerSupportsPortRanges(peerVer string) bool { - if strings.Contains(peerVer, "dev") { - return true +// addNativeSSHRule adds a native SSH rule (port 22022) to the expanded rules if the base rule has port 22 configured. +func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule { + shouldAdd := false + for _, fr := range expanded { + if isPortInRule(nativeSSHPortString, 22022, fr) { + return expanded + } + if isPortInRule(defaultSSHPortString, 22, fr) { + shouldAdd = true + } + } + if !shouldAdd { + return expanded } - meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) - return err == nil && meetMinVer + fr := base + fr.Port = nativeSSHPortString + return append(expanded, &fr) +} + +func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool { + return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End) +} + +// shouldCheckRulesForNativeSSH determines whether specific policy rules should be checked for native SSH support. +// While users can add the nativeSSHPortString, we look for cases when they used port 22 and based on SSH enabled +// in both management and client, we indicate to add the native port. +func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool { + return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP +} + +// peerSupportedFirewallFeatures checks if the peer version supports port ranges. +func peerSupportedFirewallFeatures(peerVer string) supportedFeatures { + if strings.Contains(peerVer, "dev") { + return supportedFeatures{true, true} + } + + var features supportedFeatures + + meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer) + features.nativeSSH = err == nil && meetMinVer + + if features.nativeSSH { + features.portRanges = true + } else { + meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) + features.portRanges = err == nil && meetMinVer + } + + return features } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 32538933a..f9aa6a1c2 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -839,6 +839,272 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, sourcePeers, 2, "expected source peers don't match") } +func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + rule *PolicyRule + base FirewallRule + expectedPorts []string + }{ + { + name: "adds port 22022 when SSH enabled on modern peer with port 22", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "adds port 22022 once when port 22 is duplicated within policy", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22", "80", "22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "80", "22", "22022"}, + }, + { + name: "does not add 22022 for peer with old version", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when SSHEnabled is false", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: false, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when ServerSSHAllowed is false", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: false}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 for UDP protocol", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolUDP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "udp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when port 22 not in rule", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"80", "443"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"80", "443"}, + }, + { + name: "does not duplicate 22022 when already present", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22", "22022"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "does not duplicate 22022 when already within a port range", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 32000}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-32000"}, + }, + { + name: "adds 22022 when port 22 in port range", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 25}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-25", "22022"}, + }, + { + name: "adds single 22022 once when port 22 in multiple port ranges", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 25}, {Start: 10, End: 100}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-25", "10-100", "22022"}, + }, + { + name: "dev suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.50.0-dev", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "dev suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "dev", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "development suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := expandPortsAndRanges(tt.base, tt.rule, tt.peer) + + var ports []string + for _, fr := range result { + if fr.Port != "" { + ports = append(ports, fr.Port) + } else if fr.PortRange.Start > 0 { + ports = append(ports, fmt.Sprintf("%d-%d", fr.PortRange.Start, fr.PortRange.End)) + } + } + + assert.Equal(t, tt.expectedPorts, ports, "expanded ports should match expected") + }) + } +} + func Test_FilterZoneRecordsForPeers(t *testing.T) { tests := []struct { name string From 131136439764a67187b7fbfadef1215b5cc66780 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:09:22 +0100 Subject: [PATCH 065/118] [client] Increase ssh detection timeout (#4827) --- client/cmd/ssh.go | 17 +++++++++++++---- client/ssh/client/client.go | 9 ++++++--- client/ssh/config/manager.go | 7 +------ client/ssh/detection/detection.go | 18 +++++++++--------- client/ssh/server/jwt_test.go | 6 +++--- client/wasm/cmd/main.go | 24 ++++++++++++++---------- 6 files changed, 46 insertions(+), 35 deletions(-) diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 70c7dbcff..92857c637 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -749,7 +749,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error { if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { logOutput = firstLogFile } - if err := util.InitLog(logLevel, logOutput); err != nil { + + proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(proxyLogLevel, logOutput); err != nil { return fmt.Errorf("init log: %w", err) } @@ -788,7 +790,8 @@ var sshDetectCmd = &cobra.Command{ } func sshDetectFn(cmd *cobra.Command, args []string) error { - if err := util.InitLog(logLevel, "console"); err != nil { + detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(detectLogLevel, "console"); err != nil { os.Exit(detection.ServerTypeRegular.ExitCode()) } @@ -797,15 +800,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error { port, err := strconv.Atoi(portStr) if err != nil { + log.Debugf("invalid port %q: %v", portStr, err) os.Exit(detection.ServerTypeRegular.ExitCode()) } - dialer := &net.Dialer{Timeout: detection.Timeout} - serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port) + ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout) + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) if err != nil { + log.Debugf("SSH server detection failed: %v", err) + cancel() os.Exit(detection.ServerTypeRegular.ExitCode()) } + cancel() os.Exit(serverType.ExitCode()) return nil } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 882056374..31b80317a 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -343,10 +343,13 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo return nil, fmt.Errorf("parse port %s: %w", portStr, err) } - dialer := &net.Dialer{Timeout: detection.Timeout} - serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) + detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout) + defer cancel() + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port) if err != nil { - return nil, fmt.Errorf("SSH server detection failed: %w", err) + return nil, fmt.Errorf("SSH server detection: %w", err) } if !serverType.RequiresJWT() { diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 03a136de3..cc47fd2d2 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -189,12 +189,7 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { hostLine := strings.Join(deduplicatedPatterns, " ") config := fmt.Sprintf("Host %s\n", hostLine) - - if runtime.GOOS == "windows" { - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) - } else { - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath) - } + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) config += " PreferredAuthentications password,publickey,keyboard-interactive\n" config += " PasswordAuthentication yes\n" config += " PubkeyAuthentication yes\n" diff --git a/client/ssh/detection/detection.go b/client/ssh/detection/detection.go index 487f4665a..f23ea4c37 100644 --- a/client/ssh/detection/detection.go +++ b/client/ssh/detection/detection.go @@ -3,6 +3,7 @@ package detection import ( "bufio" "context" + "fmt" "net" "strconv" "strings" @@ -19,8 +20,8 @@ const ( // JWTRequiredMarker is appended to responses when JWT is required JWTRequiredMarker = "NetBird-JWT-Required" - // Timeout is the timeout for SSH server detection - Timeout = 5 * time.Second + // DefaultTimeout is the default timeout for SSH server detection + DefaultTimeout = 5 * time.Second ) type ServerType string @@ -61,21 +62,20 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i conn, err := dialer.DialContext(ctx, "tcp", targetAddr) if err != nil { - log.Debugf("SSH connection failed for detection: %v", err) - return ServerTypeRegular, nil + return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err) } defer conn.Close() - if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil { - log.Debugf("set read deadline: %v", err) - return ServerTypeRegular, nil + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetReadDeadline(deadline); err != nil { + return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err) + } } reader := bufio.NewReader(conn) serverBanner, err := reader.ReadString('\n') if err != nil { - log.Debugf("read SSH banner: %v", err) - return ServerTypeRegular, nil + return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err) } serverBanner = strings.TrimSpace(serverBanner) diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index e22bdfb06..1f3bac76d 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -58,7 +58,7 @@ func TestJWTEnforcement(t *testing.T) { require.NoError(t, err) port, err := strconv.Atoi(portStr) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) if err != nil { t.Logf("Detection failed: %v", err) @@ -93,7 +93,7 @@ func TestJWTEnforcement(t *testing.T) { portNoJWT, err := strconv.Atoi(portStrNoJWT) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT) require.NoError(t, err) assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType) @@ -218,7 +218,7 @@ func TestJWTDetection(t *testing.T) { port, err := strconv.Atoi(portStr) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) require.NoError(t, err) assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType) diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 4dc14a1ca..238e272fa 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -19,9 +19,10 @@ import ( ) const ( - clientStartTimeout = 30 * time.Second - clientStopTimeout = 10 * time.Second - defaultLogLevel = "warn" + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" + defaultSSHDetectionTimeout = 20 * time.Second ) func main() { @@ -207,11 +208,19 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func { host := args[0].String() port := args[1].Int() + timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds()) + if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() { + timeoutMs = args[2].Int() + if timeoutMs <= 0 { + return js.ValueOf("error: timeout must be positive") + } + } + return createPromise(func(resolve, reject js.Value) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) defer cancel() - serverType, err := detectSSHServerType(ctx, client, host, port) + serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port) if err != nil { reject.Invoke(err.Error()) return @@ -222,11 +231,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func { }) } -// detectSSHServerType detects SSH server type using NetBird network connection -func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) { - return sshdetection.DetectSSHServerType(ctx, client, host, port) -} - // createClientObject wraps the NetBird client in a JavaScript object func createClientObject(client *netbird.Client) js.Value { obj := make(map[string]interface{}) From 32146e576d9e817a4fb84a87631521eceed5b503 Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Fri, 21 Nov 2025 09:36:33 -0300 Subject: [PATCH 066/118] [android] allow selection/deselection of network resources on android peers (#4607) --- client/android/client.go | 115 ++++++- client/android/network_domains.go | 56 ++++ client/android/networks.go | 14 +- client/android/peer_notifier.go | 7 + client/android/peer_routes.go | 20 ++ client/android/platform_files.go | 10 + client/android/route_command.go | 67 ++++ client/iface/device/device_android.go | 52 ++- .../iface/device/device_netstack_android.go | 7 + client/iface/device/renewable_tun.go | 309 ++++++++++++++++++ client/iface/device_android.go | 1 + client/iface/iface_create.go | 4 + client/iface/iface_create_android.go | 7 + client/iface/iface_create_darwin.go | 4 + client/internal/connect.go | 2 + client/internal/engine.go | 14 +- client/internal/engine_test.go | 4 + client/internal/iface_common.go | 1 + 18 files changed, 666 insertions(+), 28 deletions(-) create mode 100644 client/android/network_domains.go create mode 100644 client/android/peer_routes.go create mode 100644 client/android/platform_files.go create mode 100644 client/android/route_command.go create mode 100644 client/iface/device/renewable_tun.go diff --git a/client/android/client.go b/client/android/client.go index 86fb1445d..2943702c6 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,10 +4,13 @@ package android import ( "context" + "fmt" "os" "slices" "sync" + "golang.org/x/exp/maps" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/device" @@ -16,10 +19,13 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile @@ -62,17 +68,18 @@ type Client struct { deviceName string uiVersion string networkChangeListener listener.NetworkChangeListener + stateFile string connectClient *internal.ConnectClient } // NewClient instantiate a new Client -func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { execWorkaround(androidSDKVersion) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ - cfgFile: cfgFile, + cfgFile: platformFiles.ConfigurationFilePath(), deviceName: deviceName, uiVersion: uiVersion, tunAdapter: tunAdapter, @@ -80,6 +87,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi recorder: peer.NewRecorder(""), ctxCancelLock: &sync.Mutex{}, networkChangeListener: networkChangeListener, + stateFile: platformFiles.StateFilePath(), } } @@ -115,7 +123,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -142,7 +150,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } // Stop the internal client and free the resources @@ -156,6 +164,19 @@ func (c *Client) Stop() { c.ctxCancel() } +func (c *Client) RenewTun(fd int) error { + if c.connectClient == nil { + return fmt.Errorf("engine not running") + } + + e := c.connectClient.Engine() + if e == nil { + return fmt.Errorf("engine not initialized") + } + + return e.RenewTun(fd) +} + // SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) @@ -177,6 +198,7 @@ func (c *Client) PeersList() *PeerInfoArray { p.IP, p.FQDN, p.ConnStatus.String(), + PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi } @@ -201,31 +223,43 @@ func (c *Client) Networks() *NetworkArray { return nil } + routeSelector := routeManager.GetRouteSelector() + if routeSelector == nil { + log.Error("could not get route selector") + return nil + } + networkArray := &NetworkArray{ items: make([]Network, 0), } + resolvedDomains := c.recorder.GetResolvedDomainsStates() + for id, routes := range routeManager.GetClientRoutesWithNetID() { if len(routes) == 0 { continue } r := routes[0] + domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) netStr := r.Network.String() + if r.IsDynamic() { netStr = r.Domains.SafeString() } - peer, err := c.recorder.GetPeer(routes[0].Peer) + routePeer, err := c.recorder.GetPeer(routes[0].Peer) if err != nil { log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) continue } network := Network{ - Name: string(id), - Network: netStr, - Peer: peer.FQDN, - Status: peer.ConnStatus.String(), + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: routeSelector.IsSelected(id), + Domains: domains, } networkArray.Add(network) } @@ -253,6 +287,69 @@ func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } +func (c *Client) toggleRoute(command routeCommand) error { + return command.toggleRoute() +} + +func (c *Client) getRouteManager() (routemanager.Manager, error) { + client := c.connectClient + if client == nil { + return nil, fmt.Errorf("not connected") + } + + engine := client.Engine() + if engine == nil { + return nil, fmt.Errorf("engine is not running") + } + + manager := engine.GetRouteManager() + if manager == nil { + return nil, fmt.Errorf("could not get route manager") + } + + return manager, nil +} + +func (c *Client) SelectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(selectRouteCommand{route: route, manager: manager}) +} + +func (c *Client) DeselectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(deselectRouteCommand{route: route, manager: manager}) +} + +// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain +// with its resolved IP addresses from the provided resolvedDomains map. +func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains { + domains := NetworkDomains{} + + for _, d := range route.Domains { + networkDomain := NetworkDomain{ + Address: d.SafeString(), + } + + if info, exists := resolvedDomains[d]; exists { + for _, prefix := range info.Prefixes { + networkDomain.addResolvedIP(prefix.Addr().String()) + } + } + + domains.Add(&networkDomain) + } + + return domains +} + func exportEnvList(list *EnvList) { if list == nil { return diff --git a/client/android/network_domains.go b/client/android/network_domains.go new file mode 100644 index 000000000..a459bdc23 --- /dev/null +++ b/client/android/network_domains.go @@ -0,0 +1,56 @@ +//go:build android + +package android + +import "fmt" + +type ResolvedIPs struct { + resolvedIPs []string +} + +func (r *ResolvedIPs) Add(ipAddress string) { + r.resolvedIPs = append(r.resolvedIPs, ipAddress) +} + +func (r *ResolvedIPs) Get(i int) (string, error) { + if i < 0 || i >= len(r.resolvedIPs) { + return "", fmt.Errorf("%d is out of range", i) + } + return r.resolvedIPs[i], nil +} + +func (r *ResolvedIPs) Size() int { + return len(r.resolvedIPs) +} + +type NetworkDomain struct { + Address string + resolvedIPs ResolvedIPs +} + +func (d *NetworkDomain) addResolvedIP(resolvedIP string) { + d.resolvedIPs.Add(resolvedIP) +} + +func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs { + return &d.resolvedIPs +} + +type NetworkDomains struct { + domains []*NetworkDomain +} + +func (n *NetworkDomains) Add(domain *NetworkDomain) { + n.domains = append(n.domains, domain) +} + +func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) { + if i < 0 || i >= len(n.domains) { + return nil, fmt.Errorf("%d is out of range", i) + } + return n.domains[i], nil +} + +func (n *NetworkDomains) Size() int { + return len(n.domains) +} diff --git a/client/android/networks.go b/client/android/networks.go index aa130420b..3c3a25939 100644 --- a/client/android/networks.go +++ b/client/android/networks.go @@ -3,10 +3,16 @@ package android type Network struct { - Name string - Network string - Peer string - Status string + Name string + Network string + Peer string + Status string + IsSelected bool + Domains NetworkDomains +} + +func (n Network) GetNetworkDomains() *NetworkDomains { + return &n.Domains } type NetworkArray struct { diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index 1f5564c72..b03947da1 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -1,3 +1,5 @@ +//go:build android + package android // PeerInfo describe information about the peers. It designed for the UI usage @@ -5,6 +7,11 @@ type PeerInfo struct { IP string FQDN string ConnStatus string // Todo replace to enum + Routes PeerRoutes +} + +func (p *PeerInfo) GetPeerRoutes() *PeerRoutes { + return &p.Routes } // PeerInfoArray is a wrapper of []PeerInfo diff --git a/client/android/peer_routes.go b/client/android/peer_routes.go new file mode 100644 index 000000000..bb46d609f --- /dev/null +++ b/client/android/peer_routes.go @@ -0,0 +1,20 @@ +//go:build android + +package android + +import "fmt" + +type PeerRoutes struct { + routes []string +} + +func (p *PeerRoutes) Get(i int) (string, error) { + if i < 0 || i >= len(p.routes) { + return "", fmt.Errorf("%d is out of range", i) + } + return p.routes[i], nil +} + +func (p *PeerRoutes) Size() int { + return len(p.routes) +} diff --git a/client/android/platform_files.go b/client/android/platform_files.go new file mode 100644 index 000000000..f0c369750 --- /dev/null +++ b/client/android/platform_files.go @@ -0,0 +1,10 @@ +//go:build android + +package android + +// PlatformFiles groups paths to files used internally by the engine that can't be created/modified +// at their default locations due to android OS restrictions. +type PlatformFiles interface { + ConfigurationFilePath() string + StateFilePath() string +} diff --git a/client/android/route_command.go b/client/android/route_command.go new file mode 100644 index 000000000..b47d5ca6c --- /dev/null +++ b/client/android/route_command.go @@ -0,0 +1,67 @@ +//go:build android + +package android + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/route" +) + +func executeRouteToggle(id string, manager routemanager.Manager, + operationName string, + routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error { + netID := route.NetID(id) + routes := []route.NetID{netID} + + log.Debugf("%s with id: %s", operationName, id) + + if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("error when %s: %s", operationName, err) + return fmt.Errorf("error %s: %w", operationName, err) + } + + manager.TriggerSelection(manager.GetClientRoutes()) + + return nil +} + +type routeCommand interface { + toggleRoute() error +} + +type selectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (s selectRouteCommand) toggleRoute() error { + routeSelector := s.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error { + return routeSelector.SelectRoutes(routes, true, allRoutes) + } + + return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation) +} + +type deselectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (d deselectRouteCommand) toggleRoute() error { + routeSelector := d.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes) +} diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 48346fc0f..198343fbd 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -3,6 +3,7 @@ package device import ( + "fmt" "strings" log "github.com/sirupsen/logrus" @@ -19,11 +20,12 @@ import ( // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform type WGTunDevice struct { - address wgaddr.Address - port int - key string - mtu uint16 - iceBind *bind.ICEBind + address wgaddr.Address + port int + key string + mtu uint16 + iceBind *bind.ICEBind + // todo: review if we can eliminate the TunAdapter tunAdapter TunAdapter disableDNS bool @@ -32,17 +34,19 @@ type WGTunDevice struct { filteredDevice *FilteredDevice udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer + renewableTun *RenewableTUN } func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ - address: address, - port: port, - key: key, - mtu: mtu, - iceBind: iceBind, - tunAdapter: tunAdapter, - disableDNS: disableDNS, + address: address, + port: port, + key: key, + mtu: mtu, + iceBind: iceBind, + tunAdapter: tunAdapter, + disableDNS: disableDNS, + renewableTun: NewRenewableTUN(), } } @@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } - tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd) + unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd) if err != nil { _ = unix.Close(fd) log.Errorf("failed to create Android interface: %s", err) return nil, err } + + t.renewableTun.AddDevice(unmonitoredTUN) + t.name = name - t.filteredDevice = newDeviceFilter(tunDevice) + t.filteredDevice = newDeviceFilter(t.renewableTun) log.Debugf("attaching to interface %v", name) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] ")) @@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return udpMux, nil } +func (t *WGTunDevice) RenewTun(fd int) error { + if t.device == nil { + return fmt.Errorf("device not initialized") + } + + unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd) + if err != nil { + _ = unix.Close(fd) + log.Errorf("failed to renew Android interface: %s", err) + return err + } + + t.renewableTun.AddDevice(unmonitoredTUN) + + return nil +} + func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error { // todo implement return nil diff --git a/client/iface/device/device_netstack_android.go b/client/iface/device/device_netstack_android.go index 45ae8ba7d..f1a77d40a 100644 --- a/client/iface/device/device_netstack_android.go +++ b/client/iface/device/device_netstack_android.go @@ -2,6 +2,13 @@ package device +import "fmt" + func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { return t.create() } + +func (t *TunNetstackDevice) RenewTun(fd int) error { + // Doesn't make sense in Android for Netstack. + return fmt.Errorf("this function has not been implemented in Netstack for Android") +} diff --git a/client/iface/device/renewable_tun.go b/client/iface/device/renewable_tun.go new file mode 100644 index 000000000..a501eebbb --- /dev/null +++ b/client/iface/device/renewable_tun.go @@ -0,0 +1,309 @@ +//go:build android + +package device + +import ( + "io" + "os" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" +) + +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +// +// It also redirects tun.Device's Events() to a separate goroutine +// and closes it when Close is called. +// +// The WaitGroup and CloseOnce fields are used to ensure that the +// goroutine is awaited and closed only once. +type closeAwareDevice struct { + isClosed atomic.Bool + tun.Device + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newClosableDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel (RenewableTUN's events channel). +func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() +} + +type RenewableTUN struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + events chan tun.Event + closed atomic.Bool +} + +func NewRenewableTUN() *RenewableTUN { + r := &RenewableTUN{ + devices: make([]*closeAwareDevice, 0), + mu: sync.Mutex{}, + events: make(chan tun.Event, 16), + } + r.cond = sync.NewCond(&r.mu) + return r +} + +func (r *RenewableTUN) File() *os.File { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// Read reads from an underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries reading from the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + for { + dev := r.peekLast() + if dev == nil { + // wait until AddDevice() signals a new device via cond.Broadcast() + if !r.waitForDevice() { // returns false if the renewable TUN itself is closed + return 0, io.EOF + } + continue + } + + n, err = dev.Read(bufs, sizes, offset) + if err == nil { + return n, nil + } + + // swap in progress; retry on the newest instead of returning the error + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return n, err // propagate non-swap error + } +} + +// Write writes to underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries writing to the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} + +func (r *RenewableTUN) MTU() (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return 0, err + } +} + +func (r *RenewableTUN) Name() (string, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return "", io.EOF + } + continue + } + name, err := dev.Name() + if err == nil { + return name, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return "", err + } +} + +// Events returns a channel that is fed events from the underlying tun.Device's events channel +// once it is added. +func (r *RenewableTUN) Events() <-chan tun.Event { + return r.events +} + +func (r *RenewableTUN) Close() error { + // Attempts to set the RenewableTUN closed flag to true. + // If it's already true, returns immediately. + if !r.closed.CompareAndSwap(false, true) { + return nil // already closed: idempotent + } + r.mu.Lock() + devices := r.devices + r.devices = nil + r.cond.Broadcast() + r.mu.Unlock() + + var lastErr error + + log.Debugf("closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + log.Debugf("error closing a device: %v", err) + lastErr = err + } + } + + close(r.events) + return lastErr +} + +func (r *RenewableTUN) BatchSize() int { + return 1 +} + +func (r *RenewableTUN) AddDevice(device tun.Device) { + r.mu.Lock() + if r.closed.Load() { + r.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(r.devices) > 0 { + toClose = r.devices[len(r.devices)-1] + } + + cad := newClosableDevice(device) + cad.redirectEvents(r.events) + + r.devices = []*closeAwareDevice{cad} + r.cond.Broadcast() + + r.mu.Unlock() + + if toClose != nil { + if err := toClose.Close(); err != nil { + log.Debugf("error closing last device: %v", err) + } + } +} + +func (r *RenewableTUN) waitForDevice() bool { + r.mu.Lock() + defer r.mu.Unlock() + + for len(r.devices) == 0 && !r.closed.Load() { + r.cond.Wait() + } + return !r.closed.Load() +} + +func (r *RenewableTUN) peekLast() *closeAwareDevice { + r.mu.Lock() + defer r.mu.Unlock() + + if len(r.devices) == 0 { + return nil + } + + return r.devices[len(r.devices)-1] +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index cdfcea48d..3899bf426 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,5 +21,6 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + RenewTun(fd int) error GetICEBind() device.EndpointManager } diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go index 5e17c6d41..13ae9393c 100644 --- a/client/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -24,3 +24,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on non mobile") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on non-android") +} diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go index 373a9c95a..d2d9eb70e 100644 --- a/client/iface/iface_create_android.go +++ b/client/iface/iface_create_android.go @@ -6,6 +6,7 @@ import ( // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. +// todo: review does this function really necessary or can we merge it with iOS func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { w.mu.Lock() defer w.mu.Unlock() @@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s func (w *WGIface) Create() error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + w.mu.Lock() + defer w.mu.Unlock() + return w.tun.RenewTun(fd) +} diff --git a/client/iface/iface_create_darwin.go b/client/iface/iface_create_darwin.go index 1d91bce54..0b7cd36ef 100644 --- a/client/iface/iface_create_darwin.go +++ b/client/iface/iface_create_darwin.go @@ -39,3 +39,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on this platform") +} diff --git a/client/internal/connect.go b/client/internal/connect.go index 6ad5f264b..5a5f4f63c 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -74,6 +74,7 @@ func (c *ConnectClient) RunOnAndroid( networkChangeListener listener.NetworkChangeListener, dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, + stateFilePath string, ) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ @@ -82,6 +83,7 @@ func (c *ConnectClient) RunOnAndroid( NetworkChangeListener: networkChangeListener, HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 94c948398..76265dd77 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -255,7 +255,7 @@ func NewEngine( sm := profilemanager.NewServiceManager("") path := sm.GetStatePath() - if runtime.GOOS == "ios" { + if runtime.GOOS == "ios" || runtime.GOOS == "android" { if !fileExists(mobileDep.StateFilePath) { err := createFile(mobileDep.StateFilePath) if err != nil { @@ -1831,6 +1831,18 @@ func (e *Engine) GetWgAddr() netip.Addr { return e.wgInterface.Address().IP } +func (e *Engine) RenewTun(fd int) error { + e.syncMsgMux.Lock() + wgInterface := e.wgInterface + e.syncMsgMux.Unlock() + + if wgInterface == nil { + return fmt.Errorf("wireguard interface not initialized") + } + + return wgInterface.RenewTun(fd) +} + // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag func (e *Engine) updateDNSForwarder( enabled bool, diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d6ab7391e..9252ce13e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -110,6 +110,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RenewTun(_ int) error { + return nil +} + func (m *MockWGIface) RemoveEndpointAddress(_ string) error { return nil } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 98fe01912..90b06cbd1 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -20,6 +20,7 @@ import ( type wgIfaceBase interface { Create() error CreateOnAndroid(routeRange []string, ip string, domains []string) error + RenewTun(fd int) error IsUserspaceBind() bool Name() string Address() wgaddr.Address From 7fb1a2fe312f8549095465469e64a4af3b155a59 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Sat, 22 Nov 2025 01:23:33 +0100 Subject: [PATCH 067/118] [management] removed TestBufferUpdateAccountPeers because it was incorrect (#4839) --- .../network_map/controller/controller_test.go | 135 ------------------ 1 file changed, 135 deletions(-) diff --git a/management/internals/controllers/network_map/controller/controller_test.go b/management/internals/controllers/network_map/controller/controller_test.go index baaffe677..90e7b6e18 100644 --- a/management/internals/controllers/network_map/controller/controller_test.go +++ b/management/internals/controllers/network_map/controller/controller_test.go @@ -1,16 +1,9 @@ package controller import ( - "context" - "sync" - "sync/atomic" "testing" - "time" - - "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/internals/controllers/network_map" - "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -114,131 +107,3 @@ func TestComputeForwarderPort(t *testing.T) { t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result) } } - -func TestBufferUpdateAccountPeers(t *testing.T) { - const ( - peersCount = 1000 - updateAccountInterval = 50 * time.Millisecond - ) - - var ( - deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 - uapLastRun, dpLastRun atomic.Int64 - - totalNewRuns, totalOldRuns int - ) - - uap := func(ctx context.Context, accountID string) { - updatePeersDeleted.Store(deletedPeers.Load()) - updatePeersRuns.Add(1) - uapLastRun.Store(time.Now().UnixMilli()) - time.Sleep(100 * time.Millisecond) - } - - t.Run("new approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) - b := mu.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - uap(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - b.next = time.AfterFunc(updateAccountInterval, func() { - uap(ctx, accountID) - }) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalNewRuns = int(updatePeersRuns.Load()) - }) - - t.Run("old approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) - b := mu.(*sync.Mutex) - - if !b.TryLock() { - return - } - - go func() { - time.Sleep(updateAccountInterval) - b.Unlock() - uap(ctx, accountID) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalOldRuns = int(updatePeersRuns.Load()) - }) - assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) - t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) -} From 290fe2d8b91483073f1c9dba439f464a910cfc61 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 22 Nov 2025 10:10:18 +0100 Subject: [PATCH 068/118] [client/management/signal/relay] Update go.mod to use Go 1.24.10 and upgrade x/crypto dependencies (#4828) Upgrade Go toolchain and golang.org/x/* deps to 1.24.10, standardize GitHub Actions to derive Go version from go.mod and adjust checkout ordering, raise WASM size limit to 55 MB, update FreeBSD tarball and gomobile refs, fix a few format-string/logging calls, treat usernames ending with $ as system accounts, and add Windows tests. --- .github/workflows/golang-test-darwin.yml | 7 +- .github/workflows/golang-test-freebsd.yml | 2 +- .github/workflows/golang-test-linux.yml | 71 +++++------ .github/workflows/golang-test-windows.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- .github/workflows/mobile-build-validation.yml | 8 +- .github/workflows/release.yml | 8 +- .../workflows/test-infrastructure-files.yml | 8 +- .github/workflows/wasm-build-validation.yml | 8 +- client/internal/dns/upstream.go | 2 +- client/ssh/testutil/user_helpers.go | 3 +- client/ssh/testutil/user_helpers_test.go | 115 ++++++++++++++++++ go.mod | 22 ++-- go.sum | 40 +++--- .../policies/posture_checks_handler_test.go | 2 +- management/server/posture_checks.go | 2 +- 16 files changed, 210 insertions(+), 92 deletions(-) create mode 100644 client/ssh/testutil/user_helpers_test.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 4571ce753..9c4c35d21 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -15,13 +15,14 @@ jobs: name: "Client / Unit" runs-on: macos-latest steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - name: Cache Go modules uses: actions/cache@v4 diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index cdd0910a4..b03313bbd 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -25,7 +25,7 @@ jobs: release: "14.2" prepare: | pkg install -y curl pkgconf xorg - GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz" + GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" GO_URL="https://go.dev/dl/$GO_TARBALL" curl -vLO "$GO_URL" tar -C /usr/local -vxzf "$GO_TARBALL" diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ba36c013b..c09bfab39 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -30,7 +30,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Get Go environment @@ -106,15 +106,15 @@ jobs: arch: [ '386','amd64' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -151,15 +151,15 @@ jobs: needs: [ build-cache ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment id: go-env run: | @@ -200,7 +200,7 @@ jobs: -e GOCACHE=${CONTAINER_GOCACHE} \ -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ -e CONTAINER=${CONTAINER} \ - golang:1.23-alpine \ + golang:1.24-alpine \ sh -c ' \ apk update; apk add --no-cache \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ @@ -220,15 +220,15 @@ jobs: raceFlag: "-race" runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 @@ -270,15 +270,15 @@ jobs: arch: [ '386','amd64' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 @@ -321,15 +321,15 @@ jobs: store: [ 'sqlite', 'postgres', 'mysql' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -408,15 +408,16 @@ jobs: -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ -p 9090:9090 \ prom/prometheus - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: "1.23.x" - cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version-file: "go.mod" + cache: false + - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -497,15 +498,15 @@ jobs: -p 9090:9090 \ prom/prometheus + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -561,15 +562,15 @@ jobs: store: [ 'sqlite', 'postgres'] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2083c0721..43357c45f 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -24,7 +24,7 @@ jobs: uses: actions/setup-go@v5 id: go with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Get Go environment diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 2845b05a5..c524f6f6b 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -46,7 +46,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index c7d43695b..8325fbf2d 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -20,7 +20,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Setup Android SDK uses: android-actions/setup-android@v3 with: @@ -39,7 +39,7 @@ jobs: - name: Setup NDK run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab - name: gomobile init run: gomobile init - name: build android netbird lib @@ -56,9 +56,9 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab - name: gomobile init run: gomobile init - name: build iOS netbird lib diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e9741f541..a9bc1b979 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ concurrency: jobs: release: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest-m env: flags: "" steps: @@ -40,7 +40,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 @@ -136,7 +136,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 @@ -200,7 +200,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index 3855baba2..f4513e0e1 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -67,10 +67,13 @@ jobs: - name: Install curl run: sudo apt-get install -y curl + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Cache Go modules uses: actions/cache@v4 @@ -80,9 +83,6 @@ jobs: restore-keys: | ${{ runner.os }}-go- - - name: Checkout code - uses: actions/checkout@v4 - - name: Setup MySQL privileges if: matrix.store == 'mysql' run: | diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml index e4ac799bc..4100e16dd 100644 --- a/.github/workflows/wasm-build-validation.yml +++ b/.github/workflows/wasm-build-validation.yml @@ -20,7 +20,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: Install golangci-lint @@ -45,7 +45,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Build Wasm client run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd env: @@ -60,8 +60,8 @@ jobs: echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" - if [ ${SIZE} -gt 52428800 ]; then - echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!" + if [ ${SIZE} -gt 57671680 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!" exit 1 fi diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c19e0acb5..2a92fd6d8 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add timeoutMsg += " " + peerInfo } timeoutMsg += fmt.Sprintf(" - error: %v", err) - logger.Warnf(timeoutMsg) + logger.Warn(timeoutMsg) } func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { diff --git a/client/ssh/testutil/user_helpers.go b/client/ssh/testutil/user_helpers.go index 0c1222078..8960d8dd0 100644 --- a/client/ssh/testutil/user_helpers.go +++ b/client/ssh/testutil/user_helpers.go @@ -72,7 +72,8 @@ func IsSystemAccount(username string) bool { return true } } - return false + + return strings.HasSuffix(username, "$") } // RegisterTestUserCleanup registers a test user for cleanup diff --git a/client/ssh/testutil/user_helpers_test.go b/client/ssh/testutil/user_helpers_test.go new file mode 100644 index 000000000..db2f5f06d --- /dev/null +++ b/client/ssh/testutil/user_helpers_test.go @@ -0,0 +1,115 @@ +package testutil + +import ( + "os/user" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestUserCurrentBehavior validates user.Current() behavior on Windows. +// When running as SYSTEM on a domain-joined machine, user.Current() returns: +// - Username: Computer account name (e.g., "DOMAIN\MACHINE$") +// - SID: SYSTEM SID (S-1-5-18) +func TestUserCurrentBehavior(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific test") + } + + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + t.Logf("Current user - Username: %s, SID: %s", currentUser.Username, currentUser.Uid) + + // When running as SYSTEM, validate expected behavior + if currentUser.Uid == "S-1-5-18" { + t.Run("SYSTEM_account_behavior", func(t *testing.T) { + // SID must be S-1-5-18 for SYSTEM + require.Equal(t, "S-1-5-18", currentUser.Uid, + "SYSTEM account must have SID S-1-5-18") + + // Username can be either "NT AUTHORITY\SYSTEM" (standalone) + // or "DOMAIN\MACHINE$" (domain-joined) + username := currentUser.Username + isNTAuthority := strings.Contains(strings.ToUpper(username), "NT AUTHORITY") + isComputerAccount := strings.HasSuffix(username, "$") + + assert.True(t, isNTAuthority || isComputerAccount, + "Username should be either 'NT AUTHORITY\\SYSTEM' or computer account (ending with $), got: %s", + username) + + if isComputerAccount { + t.Logf("SYSTEM as computer account: %s", username) + } else if isNTAuthority { + t.Logf("SYSTEM as NT AUTHORITY\\SYSTEM") + } + }) + } + + // Validate that IsSystemAccount correctly identifies system accounts + t.Run("IsSystemAccount_validation", func(t *testing.T) { + // Test with current user if it's a system account + if currentUser.Uid == "S-1-5-18" || // SYSTEM + currentUser.Uid == "S-1-5-19" || // LOCAL SERVICE + currentUser.Uid == "S-1-5-20" { // NETWORK SERVICE + + result := IsSystemAccount(currentUser.Username) + assert.True(t, result, + "IsSystemAccount should recognize system account: %s (SID: %s)", + currentUser.Username, currentUser.Uid) + } + + // Test explicit cases + testCases := []struct { + username string + expected bool + reason string + }{ + {"NT AUTHORITY\\SYSTEM", true, "NT AUTHORITY\\SYSTEM"}, + {"system", true, "system"}, + {"SYSTEM", true, "SYSTEM (case insensitive)"}, + {"NT AUTHORITY\\LOCAL SERVICE", true, "LOCAL SERVICE"}, + {"NT AUTHORITY\\NETWORK SERVICE", true, "NETWORK SERVICE"}, + {"DOMAIN\\MACHINE$", true, "computer account (ends with $)"}, + {"WORKGROUP\\WIN2K19-C2$", true, "computer account (ends with $)"}, + {"Administrator", false, "Administrator is not a system account"}, + {"alice", false, "regular user"}, + {"DOMAIN\\alice", false, "domain user"}, + } + + for _, tc := range testCases { + t.Run(tc.username, func(t *testing.T) { + result := IsSystemAccount(tc.username) + assert.Equal(t, tc.expected, result, + "IsSystemAccount(%q) should be %v because: %s", + tc.username, tc.expected, tc.reason) + }) + } + }) +} + +// TestComputerAccountDetection validates computer account detection. +func TestComputerAccountDetection(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific test") + } + + computerAccounts := []string{ + "MACHINE$", + "WIN2K19-C2$", + "DOMAIN\\MACHINE$", + "WORKGROUP\\SERVER$", + "server.domain.com$", + } + + for _, account := range computerAccounts { + t.Run(account, func(t *testing.T) { + result := IsSystemAccount(account) + assert.True(t, result, + "Computer account %q should be recognized as system account", account) + }) + } +} diff --git a/go.mod b/go.mod index 45a36190d..a72d09dc3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/netbirdio/netbird -go 1.23.1 +go 1.24.10 require ( cunicu.li/go-rosenpass v0.4.0 @@ -17,8 +17,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.41.0 - golang.org/x/sys v0.35.0 + golang.org/x/crypto v0.45.0 + golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -105,12 +105,12 @@ require ( go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 - golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/mod v0.26.0 - golang.org/x/net v0.42.0 + golang.org/x/mobile v0.0.0-20251113184115-a159579294ab + golang.org/x/mod v0.30.0 + golang.org/x/net v0.47.0 golang.org/x/oauth2 v0.30.0 - golang.org/x/sync v0.16.0 - golang.org/x/term v0.34.0 + golang.org/x/sync v0.18.0 + golang.org/x/term v0.37.0 golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 @@ -251,9 +251,9 @@ require ( go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/image v0.24.0 // indirect - golang.org/x/text v0.28.0 // indirect - golang.org/x/tools v0.35.0 // indirect + golang.org/x/image v0.33.0 // indirect + golang.org/x/text v0.31.0 // indirect + golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect diff --git a/go.sum b/go.sum index ec68a8f59..011a5f199 100644 --- a/go.sum +++ b/go.sum @@ -600,19 +600,19 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ= -golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8= +golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= +golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc= +golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q= +golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -622,8 +622,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -647,8 +647,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= @@ -665,8 +665,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -703,8 +703,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -717,8 +717,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= -golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -730,8 +730,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -749,8 +749,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index 8c60d6fe8..35198da32 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint } return postureChecks, nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index ac8ea35de..9a743eb8c 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -158,7 +158,7 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.St // validatePostureChecks validates the posture checks. func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. From 131d7a36943892ee33631158fb6cd5e878f5a4fd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 22 Nov 2025 18:57:07 +0100 Subject: [PATCH 069/118] [client] Make mss clamping optional for nftables (#4843) --- client/firewall/nftables/router_linux.go | 35 +++++++++--------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 6192c92aa..e4debc179 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -91,11 +91,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou var err error r.filterTable, err = r.loadFilterTable() if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, fmt.Errorf("load filter table: %w", err) - } + log.Warnf("failed to load filter table, skipping accept rules: %v", err) } return r, nil @@ -175,7 +171,7 @@ func (r *router) removeNatPreroutingRules() error { func (r *router) loadFilterTable() (*nftables.Table, error) { tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return nil, fmt.Errorf("unable to list tables: %v", err) + return nil, fmt.Errorf("list tables: %w", err) } for _, table := range tables { @@ -193,8 +189,6 @@ func (r *router) createContainers() error { Table: r.workTable, }) - insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) - prio := *nftables.ChainPriorityNATSource - 1 r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, @@ -236,9 +230,12 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) - // Add the single NAT rule that matches on mark - if err := r.addPostroutingRules(); err != nil { - return fmt.Errorf("add single nat rule: %v", err) + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + r.addPostroutingRules() + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("initialize tables: %v", err) } if err := r.addMSSClampingRules(); err != nil { @@ -250,11 +247,7 @@ func (r *router) createContainers() error { } if err := r.refreshRulesMap(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - if err := r.conn.Flush(); err != nil { - return fmt.Errorf("initialize tables: %v", err) + log.Errorf("failed to refresh rules: %s", err) } return nil @@ -695,7 +688,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { } // addPostroutingRules adds the masquerade rules -func (r *router) addPostroutingRules() error { +func (r *router) addPostroutingRules() { // First masquerade rule for traffic coming in from WireGuard interface exprs := []expr.Any{ // Match on the first fwmark @@ -761,8 +754,6 @@ func (r *router) addPostroutingRules() error { Chain: r.chains[chainNameRoutingNat], Exprs: exprs2, }) - - return nil } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. @@ -839,7 +830,7 @@ func (r *router) addMSSClampingRules() error { Exprs: exprsOut, }) - return nil + return r.conn.Flush() } // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls @@ -1068,7 +1059,7 @@ func (r *router) acceptFilterRulesNftables() error { } r.conn.InsertRule(inputRule) - return nil + return r.conn.Flush() } func (r *router) removeAcceptFilterRules() error { @@ -1196,7 +1187,7 @@ func (r *router) refreshRulesMap() error { for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf(" unable to list rules: %v", err) + return fmt.Errorf("list rules: %w", err) } for _, rule := range rules { if len(rule.UserData) > 0 { From ba2e9b6d88e57a760341ac83473ede264ea32204 Mon Sep 17 00:00:00 2001 From: Aziz Hasanain Date: Mon, 24 Nov 2025 14:12:51 +0300 Subject: [PATCH 070/118] [management] Fix SSH JWT issuer derivation for IDPs with path components (#4844) --- management/internals/shared/grpc/conversion.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 148db4e37..2b15fe4b8 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -369,7 +369,7 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi } issuer := strings.TrimSpace(config.AuthIssuer) - if issuer == "" || deviceFlowConfig != nil { + if issuer == "" && deviceFlowConfig != nil { if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" { issuer = d } From 20973063d865414aa551bd3daa73078a66a1ea9b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 24 Nov 2025 17:50:08 +0100 Subject: [PATCH 071/118] [client] Support disable search domain for custom zones (#4826) Two new boolean flags, SearchDomainDisabled and SkipPTRProcess, are added to CustomZone and its protobuf; they are propagated through the engine to DNS host logic. Host matching now uses SearchDomainDisabled directly, and PTR collection skips zones with SkipPTRProcess; reverse zones are initialized with SearchDomainDisabled: true. --- client/internal/dns.go | 8 +- client/internal/dns/host.go | 8 +- client/internal/engine.go | 4 +- dns/dns.go | 4 + shared/management/proto/management.pb.go | 312 ++++++++++++----------- shared/management/proto/management.proto | 2 + 6 files changed, 183 insertions(+), 155 deletions(-) diff --git a/client/internal/dns.go b/client/internal/dns.go index 5e604bec5..3c68e4d00 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple var records []nbdns.SimpleRecord for _, zone := range config.CustomZones { + if zone.SkipPTRProcess { + continue + } for _, record := range zone.Records { if record.Type != int(dns.TypeA) { continue @@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) { records := collectPTRRecords(config, network) reverseZone := nbdns.CustomZone{ - Domain: zoneName, - Records: records, + Domain: zoneName, + Records: records, + SearchDomainDisabled: true, } config.CustomZones = append(config.CustomZones, reverseZone) diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index fa474afde..f7dc46a6b 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -11,11 +11,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -const ( - ipv4ReverseZone = ".in-addr.arpa." - ipv6ReverseZone = ".ip6.arpa." -) - type hostManager interface { applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error @@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H } for _, customZone := range dnsConfig.CustomZones { - matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) config.Domains = append(config.Domains, DomainConfig{ Domain: strings.ToLower(dns.Fqdn(customZone.Domain)), - MatchOnly: matchOnly, + MatchOnly: customZone.SearchDomainDisabled, }) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 76265dd77..0ff1006cd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1207,7 +1207,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns for _, zone := range protoDNSConfig.GetCustomZones() { dnsZone := nbdns.CustomZone{ - Domain: zone.GetDomain(), + Domain: zone.GetDomain(), + SearchDomainDisabled: zone.GetSearchDomainDisabled(), + SkipPTRProcess: zone.GetSkipPTRProcess(), } for _, record := range zone.Records { dnsRecord := nbdns.SimpleRecord{ diff --git a/dns/dns.go b/dns/dns.go index cf089d4ed..aa0e16eb1 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -45,6 +45,10 @@ type CustomZone struct { Domain string // Records custom zone records Records []SimpleRecord + // SearchDomainDisabled indicates whether to add match domains to a search domains list or not + SearchDomainDisabled bool + // SkipPTRProcess indicates whether a client should process PTR records from custom zones + SkipPTRProcess bool } // SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 45211b7f4..2e4cf2644 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v6.32.1 // source: management.proto package proto @@ -2682,8 +2682,10 @@ type CustomZone struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"` - Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"` + Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"` + Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"` + SearchDomainDisabled bool `protobuf:"varint,3,opt,name=SearchDomainDisabled,proto3" json:"SearchDomainDisabled,omitempty"` + SkipPTRProcess bool `protobuf:"varint,4,opt,name=SkipPTRProcess,proto3" json:"SkipPTRProcess,omitempty"` } func (x *CustomZone) Reset() { @@ -2732,6 +2734,20 @@ func (x *CustomZone) GetRecords() []*SimpleRecord { return nil } +func (x *CustomZone) GetSearchDomainDisabled() bool { + if x != nil { + return x.SearchDomainDisabled + } + return false +} + +func (x *CustomZone) GetSkipPTRProcess() bool { + if x != nil { + return x.SkipPTRProcess + } + return false +} + // SimpleRecord represents a dns.SimpleRecord type SimpleRecord struct { state protoimpl.MessageState @@ -3893,157 +3909,163 @@ var file_management_proto_rawDesc = []byte{ 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, - 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, - 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, - 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, - 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, - 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, - 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, - 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, - 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, + 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, + 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, + 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, + 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, + 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, + 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, + 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, + 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, + 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, + 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, + 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, + 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, + 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, + 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, + 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, + 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, + 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, + 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, + 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, + 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, + 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, + 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, - 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, - 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, - 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, - 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, - 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, - 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, - 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, - 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, - 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, - 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, - 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, - 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, - 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, - 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, - 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, - 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, + 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, + 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, - 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, - 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, - 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, - 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, - 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, - 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, - 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, - 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, - 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, - 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, - 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, - 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, - 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, - 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, + 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, + 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, + 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, + 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, + 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, + 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, + 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, + 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, + 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, + 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, + 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, + 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, - 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, - 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, + 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, + 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, + 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, + 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, + 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, + 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 8a297e75e..dc60b026d 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -433,6 +433,8 @@ message DNSConfig { message CustomZone { string Domain = 1; repeated SimpleRecord Records = 2; + bool SearchDomainDisabled = 3; + bool SkipPTRProcess = 4; } // SimpleRecord represents a dns.SimpleRecord From 7285fef0f041483719b0d926fab70d315e7d3314 Mon Sep 17 00:00:00 2001 From: shuuri-labs <61762328+shuuri-labs@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:51:16 +0100 Subject: [PATCH 072/118] feat: Add support for displaying device code (UserCode) on Android TV SSO flow (#4800) - Modified URLOpener interface to pass userCode alongside URL in login.go - added ability to force device auth flow --- client/android/client.go | 4 ++-- client/android/login.go | 16 ++++++++-------- client/cmd/login.go | 2 +- client/internal/auth/oauth.go | 9 +++++++-- client/ios/NetBirdSDK/client.go | 2 +- client/server/server.go | 4 ++-- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/client/android/client.go b/client/android/client.go index 2943702c6..0d5474c4b 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -92,7 +92,7 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, @@ -115,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead c.ctxCancelLock.Unlock() auth := NewAuthWithConfig(ctx, cfg) - err = auth.login(urlOpener) + err = auth.login(urlOpener, isAndroidTV) if err != nil { return err } diff --git a/client/android/login.go b/client/android/login.go index 16df24ba8..4d4c7a650 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -32,7 +32,7 @@ type ErrListener interface { // URLOpener it is a callback interface. The Open function will be triggered if // the backend want to show an url for the user type URLOpener interface { - Open(string) + Open(url string, userCode string) OnLoginSuccess() } @@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string } // Login try register the client on the server -func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { +func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) { go func() { - err := a.login(urlOpener) + err := a.login(urlOpener, isAndroidTV) if err != nil { resultListener.OnError(err) } else { @@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { }() } -func (a *Auth) login(urlOpener URLOpener) error { +func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { var needsLogin bool // check if we need to generate JWT token @@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error { jwtToken := "" if needsLogin { - tokenInfo, err := a.foregroundGetTokenInfo(urlOpener) + tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error { return nil } -func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "") +func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "") if err != nil { return nil, err } @@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) } - go urlOpener.Open(flowInfo.VerificationURIComplete) + go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout) diff --git a/client/cmd/login.go b/client/cmd/login.go index b0c877faa..2ddcccc8a 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -332,7 +332,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro hint = profileState.Email } - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint) if err != nil { return nil, err } diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 9fbd6cf5f..85a166005 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string { return t.AccessToken } +func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool { + return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient +} + // NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration // // It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow, // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) { - if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { +// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV) +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) { + if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) { return authenticateWithDeviceCodeFlow(ctx, config, hint) } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index b0d377c21..6d969bb12 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string { ConfigPath: c.cfgFile, }) - oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "") + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "") if err != nil { return err.Error() } diff --git a/client/server/server.go b/client/server/server.go index a930e8a02..49000c092 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -504,7 +504,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if msg.Hint != nil { hint = *msg.Hint } - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err @@ -1235,7 +1235,7 @@ func (s *Server) RequestJWTAuth( } isDesktop := isUnixRunningDesktop() - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint) if err != nil { return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err) } From f31bba87b44c19472bed2162e3e189bf335127c9 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 26 Nov 2025 17:07:44 +0300 Subject: [PATCH 073/118] [management] Preserve validator settings on account settings update (#4862) --- management/server/account.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/management/server/account.go b/management/server/account.go index 3e498536c..716d5ab5d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -325,6 +325,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } + newSettings.Extra.IntegratedValidatorGroups = oldSettings.Extra.IntegratedValidatorGroups + newSettings.Extra.IntegratedValidator = oldSettings.Extra.IntegratedValidator + if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil { return err } From 02200d790bb8a5466f3b824ea6c6c2e624240b4a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:06:47 +0100 Subject: [PATCH 074/118] [client] Open browser for ssh automatically (#4838) --- client/cmd/login.go | 12 +----------- client/cmd/ssh.go | 33 ++++++++++++++++++++++++++++++- client/ssh/client/client.go | 18 ++++++++++++----- client/ssh/common.go | 36 ++++++++++++++++++++++++++++------ client/ssh/proxy/proxy.go | 30 +++++++++++++++------------- client/ssh/proxy/proxy_test.go | 2 +- util/common.go | 15 +++++++++++++- 7 files changed, 107 insertions(+), 39 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 2ddcccc8a..a34bb7c70 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,14 +4,12 @@ import ( "context" "fmt" "os" - "os/exec" "os/user" "runtime" "strings" "time" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := openBrowser(verificationURIComplete); err != nil { + if err := util.OpenBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } -// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. -func openBrowser(url string) error { - if browser := os.Getenv("BROWSER"); browser != "" { - return exec.Command(browser, url).Start() - } - return open.Run(url) -} - // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 92857c637..525bcdef1 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -51,6 +51,7 @@ var ( identityFile string skipCachedToken bool requestPTY bool + sshNoBrowser bool ) var ( @@ -81,6 +82,7 @@ func init() { sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") _ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used") sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc) sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") @@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string { return defaultValue } +// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes +func getBoolEnvOrDefault(flagName string, defaultValue bool) bool { + if envValue := os.Getenv("WT_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + if envValue := os.Getenv("NB_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + return defaultValue +} + // resetSSHGlobals sets SSH globals to their default values func resetSSHGlobals() { port = sshserver.DefaultSSHPort @@ -196,6 +213,7 @@ func resetSSHGlobals() { strictHostKeyChecking = true knownHostsFile = "" identityFile = "" + sshNoBrowser = false } // parseCustomSSHFlags extracts -L, -R flags and returns filtered args @@ -370,6 +388,7 @@ type sshFlags struct { KnownHostsFile string IdentityFile string SkipCachedToken bool + NoBrowser bool ConfigPath string LogLevel string LocalForwards []string @@ -381,6 +400,7 @@ type sshFlags struct { func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { defaultConfigPath := getEnvOrDefault("CONFIG", configPath) defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false) fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) fs.SetOutput(nil) @@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc) fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") @@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error { knownHostsFile = flags.KnownHostsFile identityFile = flags.IdentityFile skipCachedToken = flags.SkipCachedToken + sshNoBrowser = flags.NoBrowser if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { configPath = flags.ConfigPath @@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { DaemonAddr: daemonAddr, SkipCachedToken: skipCachedToken, InsecureSkipVerify: !strictHostKeyChecking, + NoBrowser: sshNoBrowser, }) if err != nil { @@ -763,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid port: %s", portStr) } - proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr()) + // Check env var for browser setting since this command is invoked via SSH ProxyCommand + // where command-line flags cannot be passed. Default is to open browser. + noBrowser := getBoolEnvOrDefault("NO_BROWSER", false) + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener) if err != nil { return fmt.Errorf("create SSH proxy: %w", err) } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 31b80317a..aab222093 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/util" ) const ( @@ -278,6 +279,7 @@ type DialOptions struct { DaemonAddr string SkipCachedToken bool InsecureSkipVerify bool + NoBrowser bool } // Dial connects to the given ssh server with specified options @@ -307,7 +309,7 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er config.Auth = append(config.Auth, authMethod) } - return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken) + return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser) } // dialSSH establishes an SSH connection without JWT authentication @@ -333,7 +335,7 @@ func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig } // dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection -func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) { +func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) { host, portStr, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("parse address %s: %w", addr, err) @@ -359,7 +361,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout) defer cancel() - jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache) + jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser) if err != nil { return nil, fmt.Errorf("request JWT token: %w", err) } @@ -369,7 +371,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo } // requestJWTToken requests a JWT token from the NetBird daemon -func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) { +func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) { hint := profilemanager.GetLoginHint() conn, err := connectToDaemon(daemonAddr) @@ -379,7 +381,13 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st defer conn.Close() client := proto.NewDaemonServiceClient(conn) - return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint) + + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener) } // verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon diff --git a/client/ssh/common.go b/client/ssh/common.go index 3beb12806..6574437b5 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -67,8 +67,31 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe return VerifyHostKey(storedKeyData, presentedKey, peerAddress) } +// printAuthInstructions prints authentication instructions to stderr +func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) { + _, _ = fmt.Fprintln(stderr, "SSH authentication required.") + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.") + _, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:") + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete) + + if authResponse.UserCode != "" { + _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + } + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") +} + // RequestJWTToken requests or retrieves a JWT token for SSH authentication -func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) { +func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) { req := &proto.RequestJWTAuthRequest{} if hint != "" { req.Hint = &hint @@ -84,12 +107,13 @@ func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdo } if stderr != nil { - _, _ = fmt.Fprintln(stderr, "SSH authentication required.") - _, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete) - if authResponse.UserCode != "" { - _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + printAuthInstructions(stderr, authResponse, openBrowser != nil) + } + + if openBrowser != nil { + if err := openBrowser(authResponse.VerificationURIComplete); err != nil { + log.Debugf("open browser: %v", err) } - _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") } tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{ diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index bc8a84b89..4e807e33c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -35,15 +35,16 @@ const ( ) type SSHProxy struct { - daemonAddr string - targetHost string - targetPort int - stderr io.Writer - conn *grpc.ClientConn - daemonClient proto.DaemonServiceClient + daemonAddr string + targetHost string + targetPort int + stderr io.Writer + conn *grpc.ClientConn + daemonClient proto.DaemonServiceClient + browserOpener func(string) error } -func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) { +func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) { grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://") grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -51,12 +52,13 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP } return &SSHProxy{ - daemonAddr: daemonAddr, - targetHost: targetHost, - targetPort: targetPort, - stderr: stderr, - conn: grpcConn, - daemonClient: proto.NewDaemonServiceClient(grpcConn), + daemonAddr: daemonAddr, + targetHost: targetHost, + targetPort: targetPort, + stderr: stderr, + conn: grpcConn, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + browserOpener: browserOpener, }, nil } @@ -70,7 +72,7 @@ func (p *SSHProxy) Close() error { func (p *SSHProxy) Connect(ctx context.Context) error { hint := profilemanager.GetLoginHint() - jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint) + jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener) if err != nil { return fmt.Errorf(jwtAuthErrorMsg, err) } diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index c5036da37..582f9c07b 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -153,7 +153,7 @@ func TestSSHProxy_Connect(t *testing.T) { validToken := generateValidJWT(t, privateKey, issuer, audience) mockDaemon.setJWTToken(validToken) - proxyInstance, err := New(mockDaemon.addr, host, port, nil) + proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil) require.NoError(t, err) clientConn, proxyConn := net.Pipe() diff --git a/util/common.go b/util/common.go index 27adb9d13..89903b609 100644 --- a/util/common.go +++ b/util/common.go @@ -1,6 +1,19 @@ package util -import "os" +import ( + "os" + "os/exec" + + "github.com/skratchdot/open-golang/open" +) + +// OpenBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func OpenBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} // SliceDiff returns the elements in slice `x` that are not in slice `y` func SliceDiff(x, y []string) []string { From aca0398105fd0662c09533e1368a8682310efd94 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 26 Nov 2025 16:07:45 +0100 Subject: [PATCH 075/118] [client] Add excluded port range handling for PKCE flow (#4853) --- client/internal/auth/pkce_flow.go | 46 +++++- client/internal/auth/pkce_flow_other.go | 8 + client/internal/auth/pkce_flow_test.go | 147 ++++++++++++++++++ client/internal/auth/pkce_flow_windows.go | 86 ++++++++++ .../internal/auth/pkce_flow_windows_test.go | 116 ++++++++++++++ 5 files changed, 398 insertions(+), 5 deletions(-) create mode 100644 client/internal/auth/pkce_flow_other.go create mode 100644 client/internal/auth/pkce_flow_windows.go create mode 100644 client/internal/auth/pkce_flow_windows_test.go diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 48873f640..39da2a79f 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "time" @@ -46,9 +47,10 @@ type PKCEAuthorizationFlow struct { func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { var availableRedirectURL string - // find the first available redirect URL + excludedRanges := getSystemExcludedPortRanges() + for _, redirectURL := range config.RedirectURLs { - if !isRedirectURLPortUsed(redirectURL) { + if !isRedirectURLPortUsed(redirectURL, excludedRanges) { availableRedirectURL = redirectURL break } @@ -282,15 +284,22 @@ func createCodeChallenge(codeVerifier string) string { return base64.RawURLEncoding.EncodeToString(sha2[:]) } -// isRedirectURLPortUsed checks if the port used in the redirect URL is in use. -func isRedirectURLPortUsed(redirectURL string) bool { +// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows. +func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool { parsedURL, err := url.Parse(redirectURL) if err != nil { log.Errorf("failed to parse redirect URL: %v", err) return true } - addr := fmt.Sprintf(":%s", parsedURL.Port()) + port := parsedURL.Port() + + if isPortInExcludedRange(port, excludedRanges) { + log.Warnf("port %s is in Windows excluded port range, skipping", port) + return true + } + + addr := fmt.Sprintf(":%s", port) conn, err := net.DialTimeout("tcp", addr, 3*time.Second) if err != nil { return false @@ -304,6 +313,33 @@ func isRedirectURLPortUsed(redirectURL string) bool { return true } +// excludedPortRange represents a range of excluded ports. +type excludedPortRange struct { + start int + end int +} + +// isPortInExcludedRange checks if the given port is in any of the excluded ranges. +func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool { + if len(excludedRanges) == 0 { + return false + } + + portNum, err := strconv.Atoi(port) + if err != nil { + log.Debugf("invalid port number %s: %v", port, err) + return false + } + + for _, r := range excludedRanges { + if portNum >= r.start && portNum <= r.end { + return true + } + } + + return false +} + func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) { tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl) if err != nil { diff --git a/client/internal/auth/pkce_flow_other.go b/client/internal/auth/pkce_flow_other.go new file mode 100644 index 000000000..96df41539 --- /dev/null +++ b/client/internal/auth/pkce_flow_other.go @@ -0,0 +1,8 @@ +//go:build !windows + +package auth + +// getSystemExcludedPortRanges returns nil on non-Windows platforms. +func getSystemExcludedPortRanges() []excludedPortRange { + return nil +} diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b2347d12d..380a360e5 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -2,8 +2,11 @@ package auth import ( "context" + "fmt" + "net" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/internal" @@ -69,3 +72,147 @@ func TestPromptLogin(t *testing.T) { }) } } + +func TestIsPortInExcludedRange(t *testing.T) { + tests := []struct { + name string + port string + excludedRanges []excludedPortRange + expectedBlocked bool + }{ + { + name: "Port in excluded range", + port: "8080", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at start of range", + port: "8000", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at end of range", + port: "8100", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port before range", + port: "7999", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Port after range", + port: "8101", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty excluded ranges", + port: "8080", + excludedRanges: []excludedPortRange{}, + expectedBlocked: false, + }, + { + name: "Nil excluded ranges", + port: "8080", + excludedRanges: nil, + expectedBlocked: false, + }, + { + name: "Multiple ranges - port in second range", + port: "9050", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: true, + }, + { + name: "Multiple ranges - port not in any range", + port: "8500", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: false, + }, + { + name: "Invalid port string", + port: "invalid", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty port string", + port: "", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isPortInExcludedRange(tt.port, tt.excludedRanges) + assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch") + }) + } +} + +func TestIsRedirectURLPortUsed(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + + usedPort := listener.Addr().(*net.TCPAddr).Port + + tests := []struct { + name string + redirectURL string + excludedRanges []excludedPortRange + expectedUsed bool + }{ + { + name: "Port in excluded range", + redirectURL: "http://127.0.0.1:8080/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + { + name: "Port actually in use", + redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort), + excludedRanges: nil, + expectedUsed: true, + }, + { + name: "Port not in use and not excluded", + redirectURL: "http://127.0.0.1:65432/", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Invalid URL without port", + redirectURL: "not-a-valid-url", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Port excluded even if not in use", + redirectURL: "http://127.0.0.1:8050/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges) + assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch") + }) + } +} diff --git a/client/internal/auth/pkce_flow_windows.go b/client/internal/auth/pkce_flow_windows.go new file mode 100644 index 000000000..cf3f8718f --- /dev/null +++ b/client/internal/auth/pkce_flow_windows.go @@ -0,0 +1,86 @@ +//go:build windows + +package auth + +import ( + "bufio" + "fmt" + "os/exec" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" +) + +// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh. +func getSystemExcludedPortRanges() []excludedPortRange { + ranges, err := getExcludedPortRangesFromNetsh() + if err != nil { + log.Debugf("failed to get Windows excluded port ranges: %v", err) + return nil + } + + return ranges +} + +// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command. +func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) { + cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("netsh command: %w", err) + } + + return parseExcludedPortRanges(string(output)) +} + +// parseExcludedPortRanges parses the output of the netsh command to extract port ranges. +func parseExcludedPortRanges(output string) ([]excludedPortRange, error) { + var ranges []excludedPortRange + scanner := bufio.NewScanner(strings.NewReader(output)) + + foundHeader := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") { + foundHeader = true + continue + } + + if !foundHeader { + continue + } + + if strings.Contains(line, "----------") { + continue + } + + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + startPort, err := strconv.Atoi(fields[0]) + if err != nil { + continue + } + + endPort, err := strconv.Atoi(fields[1]) + if err != nil { + continue + } + + ranges = append(ranges, excludedPortRange{start: startPort, end: endPort}) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scan output: %w", err) + } + + return ranges, nil +} diff --git a/client/internal/auth/pkce_flow_windows_test.go b/client/internal/auth/pkce_flow_windows_test.go new file mode 100644 index 000000000..dd455b2fe --- /dev/null +++ b/client/internal/auth/pkce_flow_windows_test.go @@ -0,0 +1,116 @@ +//go:build windows + +package auth + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" +) + +func TestParseExcludedPortRanges(t *testing.T) { + tests := []struct { + name string + netshOutput string + expectedRanges []excludedPortRange + expectError bool + }{ + { + name: "Valid netsh output with multiple ranges", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 + +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 5357 5357 * + 50000 50059 * +`, + expectedRanges: []excludedPortRange{ + {start: 5357, end: 5357}, + {start: 50000, end: 50059}, + }, + expectError: false, + }, + { + name: "Empty output", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 +`, + expectedRanges: nil, + expectError: false, + }, + { + name: "Single range", + netshOutput: ` +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 8080 8090 +`, + expectedRanges: []excludedPortRange{ + {start: 8080, end: 8090}, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, err := parseExcludedPortRanges(tt.netshOutput) + + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedRanges, ranges) + } + }) + } +} + +func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) { + ranges := getSystemExcludedPortRanges() + t.Logf("Found %d excluded port ranges on this system", len(ranges)) + + listener1, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener1.Close() + }() + usedPort1 := listener1.Addr().(*net.TCPAddr).Port + + availablePort := 65432 + + config := internal.PKCEAuthProviderConfig{ + ClientID: "test-client-id", + Audience: "test-audience", + TokenEndpoint: "https://test-token-endpoint.com/token", + Scope: "openid email profile", + AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize", + RedirectURLs: []string{ + fmt.Sprintf("http://127.0.0.1:%d/", usedPort1), + fmt.Sprintf("http://127.0.0.1:%d/", availablePort), + }, + UseIDToken: true, + } + + flow, err := NewPKCEAuthorizationFlow(config) + require.NoError(t, err) + require.NotNil(t, flow) + assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort), + "Should skip port in use and select available port") +} From ddcd182859116b7ac30a7934fcc7e40fba818591 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 28 Nov 2025 17:26:22 +0100 Subject: [PATCH 076/118] [client] Sleep detection on macOS (#4859) A macOS-specific sleep detection mechanism using IOKit and CoreFoundation via cgo is introduced, with a fallback implementation for unsupported platforms. A public Service wrapper provides an event-driven API translating system sleep/wake events into gRPC calls. The UI client integrates sleep detection to manage connectivity state based on system sleep status. --- client/internal/sleep/detector_darwin.go | 218 ++++++++++++++++++ .../internal/sleep/detector_notsupported.go | 9 + client/internal/sleep/service.go | 36 +++ client/proto/daemon.pb.go | 2 +- client/ui/client_ui.go | 69 +++++- 5 files changed, 329 insertions(+), 5 deletions(-) create mode 100644 client/internal/sleep/detector_darwin.go create mode 100644 client/internal/sleep/detector_notsupported.go create mode 100644 client/internal/sleep/service.go diff --git a/client/internal/sleep/detector_darwin.go b/client/internal/sleep/detector_darwin.go new file mode 100644 index 000000000..3d6747ed1 --- /dev/null +++ b/client/internal/sleep/detector_darwin.go @@ -0,0 +1,218 @@ +//go:build darwin && !ios + +package sleep + +/* +#cgo LDFLAGS: -framework IOKit -framework CoreFoundation +#include +#include +#include + +extern void sleepCallbackBridge(); +extern void poweredOnCallbackBridge(); +extern void suspendedCallbackBridge(); +extern void resumedCallbackBridge(); + + +// C global variables for IOKit state +static IONotificationPortRef g_notifyPortRef = NULL; +static io_object_t g_notifierObject = 0; +static io_object_t g_generalInterestNotifier = 0; +static io_connect_t g_rootPort = 0; +static CFRunLoopRef g_runLoop = NULL; + +static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) { + switch (messageType) { + case kIOMessageSystemWillSleep: + sleepCallbackBridge(); + IOAllowPowerChange(g_rootPort, (long)messageArgument); + break; + case kIOMessageSystemHasPoweredOn: + poweredOnCallbackBridge(); + break; + case kIOMessageServiceIsSuspended: + suspendedCallbackBridge(); + break; + case kIOMessageServiceIsResumed: + resumedCallbackBridge(); + break; + default: + break; + } +} + +static void registerNotifications() { + g_rootPort = IORegisterForSystemPower( + NULL, + &g_notifyPortRef, + (IOServiceInterestCallback)sleepCallback, + &g_notifierObject + ); + + if (g_rootPort == 0) { + return; + } + + CFRunLoopAddSource(CFRunLoopGetCurrent(), + IONotificationPortGetRunLoopSource(g_notifyPortRef), + kCFRunLoopCommonModes); + + g_runLoop = CFRunLoopGetCurrent(); + CFRunLoopRun(); +} + +static void unregisterNotifications() { + CFRunLoopRemoveSource(g_runLoop, + IONotificationPortGetRunLoopSource(g_notifyPortRef), + kCFRunLoopCommonModes); + + IODeregisterForSystemPower(&g_notifierObject); + IOServiceClose(g_rootPort); + IONotificationPortDestroy(g_notifyPortRef); + CFRunLoopStop(g_runLoop); + + g_notifyPortRef = NULL; + g_notifierObject = 0; + g_rootPort = 0; + g_runLoop = NULL; +} + +*/ +import "C" + +import ( + "context" + "fmt" + "runtime" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + serviceRegistry = make(map[*Detector]struct{}) + serviceRegistryMu sync.Mutex +) + +//export sleepCallbackBridge +func sleepCallbackBridge() { + log.Info("sleepCallbackBridge event triggered") + + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + for svc := range serviceRegistry { + svc.triggerCallback(EventTypeSleep) + } +} + +//export resumedCallbackBridge +func resumedCallbackBridge() { + log.Info("resumedCallbackBridge event triggered") +} + +//export suspendedCallbackBridge +func suspendedCallbackBridge() { + log.Info("suspendedCallbackBridge event triggered") +} + +//export poweredOnCallbackBridge +func poweredOnCallbackBridge() { + log.Info("poweredOnCallbackBridge event triggered") + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + for svc := range serviceRegistry { + svc.triggerCallback(EventTypeWakeUp) + } +} + +type Detector struct { + callback func(event EventType) + ctx context.Context + cancel context.CancelFunc +} + +func NewDetector() (*Detector, error) { + return &Detector{}, nil +} + +func (d *Detector) Register(callback func(event EventType)) error { + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + if _, exists := serviceRegistry[d]; exists { + return fmt.Errorf("detector service already registered") + } + + d.callback = callback + + d.ctx, d.cancel = context.WithCancel(context.Background()) + + if len(serviceRegistry) > 0 { + serviceRegistry[d] = struct{}{} + return nil + } + + serviceRegistry[d] = struct{}{} + + // CFRunLoop must run on a single fixed OS thread + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + C.registerNotifications() + }() + + log.Info("sleep detection service started on macOS") + return nil +} + +// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down +// and the runloop is stopped and cleaned up. +func (d *Detector) Deregister() error { + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + _, exists := serviceRegistry[d] + if !exists { + return nil + } + + // cancel and remove this detector + d.cancel() + delete(serviceRegistry, d) + + // If other Detectors still exist, leave IOKit running + if len(serviceRegistry) > 0 { + return nil + } + + log.Info("sleep detection service stopping (deregister)") + + // Deregister IOKit notifications, stop runloop, and free resources + C.unregisterNotifications() + + return nil +} + +func (d *Detector) triggerCallback(event EventType) { + doneChan := make(chan struct{}) + + timeout := time.NewTimer(500 * time.Millisecond) + defer timeout.Stop() + + cb := d.callback + go func(callback func(event EventType)) { + log.Info("sleep detection event fired") + callback(event) + close(doneChan) + }(cb) + + select { + case <-doneChan: + case <-d.ctx.Done(): + case <-timeout.C: + log.Warnf("sleep callback timed out") + } +} diff --git a/client/internal/sleep/detector_notsupported.go b/client/internal/sleep/detector_notsupported.go new file mode 100644 index 000000000..6323bf5d1 --- /dev/null +++ b/client/internal/sleep/detector_notsupported.go @@ -0,0 +1,9 @@ +//go:build !darwin || ios + +package sleep + +import "fmt" + +func NewDetector() (detector, error) { + return nil, fmt.Errorf("sleep not supported on this platform") +} diff --git a/client/internal/sleep/service.go b/client/internal/sleep/service.go new file mode 100644 index 000000000..35fc933c0 --- /dev/null +++ b/client/internal/sleep/service.go @@ -0,0 +1,36 @@ +package sleep + +var ( + EventTypeSleep EventType = 0 + EventTypeWakeUp EventType = 1 +) + +type EventType int + +type detector interface { + Register(callback func(eventType EventType)) error + Deregister() error +} + +type Service struct { + detector detector +} + +func New() (*Service, error) { + d, err := NewDetector() + if err != nil { + return nil, err + } + + return &Service{ + detector: d, + }, nil +} + +func (s *Service) Register(callback func(eventType EventType)) error { + return s.detector.Register(callback) +} + +func (s *Service) Deregister() error { + return s.detector.Deregister() +} diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 7b9ae25f7..6f8255615 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.32.1 +// protoc v3.21.12 // source: daemon.proto package proto diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 44643616d..57d0e74a0 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -38,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/sleep" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" @@ -209,10 +210,11 @@ var iconConnectedDot []byte var iconDisconnectedDot []byte type serviceClient struct { - ctx context.Context - cancel context.CancelFunc - addr string - conn proto.DaemonServiceClient + ctx context.Context + cancel context.CancelFunc + addr string + conn proto.DaemonServiceClient + connLock sync.Mutex eventHandler *eventHandler @@ -1098,6 +1100,9 @@ func (s *serviceClient) onTrayReady() { go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) + + // Start sleep detection listener + go s.startSleepListener() } func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { @@ -1134,6 +1139,8 @@ func (s *serviceClient) onTrayExit() { // getSrvClient connection to the service. func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) { + s.connLock.Lock() + defer s.connLock.Unlock() if s.conn != nil { return s.conn, nil } @@ -1156,6 +1163,60 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService return s.conn, nil } +// startSleepListener initializes the sleep detection service and listens for sleep events +func (s *serviceClient) startSleepListener() { + sleepService, err := sleep.New() + if err != nil { + log.Warnf("%v", err) + return + } + + if err := sleepService.Register(s.handleSleepEvents); err != nil { + log.Errorf("failed to start sleep detection: %v", err) + return + } + + log.Info("sleep detection service initialized") + + // Cleanup on context cancellation + go func() { + <-s.ctx.Done() + log.Info("stopping sleep event listener") + if err := sleepService.Deregister(); err != nil { + log.Errorf("failed to deregister sleep detection: %v", err) + } + }() +} + +// handleSleepEvents sends a sleep notification to the daemon via gRPC +func (s *serviceClient) handleSleepEvents(event sleep.EventType) { + conn, err := s.getSrvClient(0) + if err != nil { + log.Errorf("failed to get daemon client for sleep notification: %v", err) + return + } + + switch event { + case sleep.EventTypeWakeUp: + log.Infof("handle wakeup event: %v", event) + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + return + } + return + case sleep.EventTypeSleep: + log.Infof("handle sleep event: %v", event) + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + return + } + } + + log.Info("successfully notified daemon about sleep/wakeup event") +} + // setSettingsEnabled enables or disables the settings menu based on the provided state func (s *serviceClient) setSettingsEnabled(enabled bool) { if s.mSettings != nil { From cb83b7c0d3ddbacdd6ff391bdcc405f8b2ad00f0 Mon Sep 17 00:00:00 2001 From: shuuri-labs <61762328+shuuri-labs@users.noreply.github.com> Date: Fri, 28 Nov 2025 21:53:53 +0100 Subject: [PATCH 077/118] [relay] use exposed address for healthcheck TLS validation (#4872) * fix(relay): use exposed address for healthcheck TLS validation Healthcheck was using listen address (0.0.0.0) instead of exposed address (domain name) for certificate validation, causing validation to always fail. Now correctly uses the exposed address where the TLS certificate is valid, matching real client connection behavior. * - store exposedAddress directly in Relay struct instead of parsing on every call - remove unused parseHostPort() function - remove unused ListenAddress() method from ServiceChecker interface - improve error logging with address context * [relay/healthcheck] Remove QUIC health check logic, update WebSocket validation flow Refactored health check logic by removing QUIC-specific connection validation and simplifying logic for WebSocket protocol. Adjusted certificate validation flow and improved handling of exposed addresses. * [relay/healthcheck] Fix certificate validation status during health check --------- Co-authored-by: Maycon Santos --- relay/healthcheck/healthcheck.go | 44 ++++++++++++-------------------- relay/healthcheck/quic.go | 31 ---------------------- relay/healthcheck/ws.go | 12 +++++++-- relay/server/relay.go | 27 ++++++++++++-------- relay/server/server.go | 8 ++---- 5 files changed, 46 insertions(+), 76 deletions(-) delete mode 100644 relay/healthcheck/quic.go diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go index eedd62394..6463843eb 100644 --- a/relay/healthcheck/healthcheck.go +++ b/relay/healthcheck/healthcheck.go @@ -6,14 +6,13 @@ import ( "errors" "net" "net/http" + "strings" "sync" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" - "github.com/netbirdio/netbird/relay/server/listener/quic" - "github.com/netbirdio/netbird/relay/server/listener/ws" ) const ( @@ -27,7 +26,7 @@ const ( type ServiceChecker interface { ListenerProtocols() []protocol.Protocol - ListenAddress() string + ExposedAddress() string } type HealthStatus struct { @@ -135,7 +134,11 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) { } status.Listeners = listeners - if ok := s.validateCertificate(ctx); !ok { + if !strings.HasPrefix(s.config.ServiceChecker.ExposedAddress(), "rels") { + status.CertificateValid = false + } + + if ok := s.validateConnection(ctx); !ok { status.Status = statusUnhealthy status.CertificateValid = false healthy = false @@ -152,32 +155,18 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) { return listeners, true } -func (s *Server) validateCertificate(ctx context.Context) bool { - listenAddress := s.config.ServiceChecker.ListenAddress() - if listenAddress == "" { - log.Warn("listen address is empty") +func (s *Server) validateConnection(ctx context.Context) bool { + exposedAddress := s.config.ServiceChecker.ExposedAddress() + if exposedAddress == "" { + log.Error("exposed address is empty, cannot validate certificate") return false } - dAddr := dialAddress(listenAddress) - - for _, proto := range s.config.ServiceChecker.ListenerProtocols() { - switch proto { - case ws.Proto: - if err := dialWS(ctx, dAddr); err != nil { - log.Errorf("failed to dial WebSocket listener: %v", err) - return false - } - case quic.Proto: - if err := dialQUIC(ctx, dAddr); err != nil { - log.Errorf("failed to dial QUIC listener: %v", err) - return false - } - default: - log.Warnf("unknown protocol for healthcheck: %s", proto) - return false - } + if err := dialWS(ctx, exposedAddress); err != nil { + log.Errorf("failed to dial WebSocket listener at %s: %v", exposedAddress, err) + return false } + return true } @@ -187,8 +176,9 @@ func dialAddress(listenAddress string) string { return listenAddress // fallback, might be invalid for dialing } + // When listening on all interfaces, show localhost for better readability if host == "" || host == "::" || host == "0.0.0.0" { - host = "0.0.0.0" + host = "localhost" } return net.JoinHostPort(host, port) diff --git a/relay/healthcheck/quic.go b/relay/healthcheck/quic.go deleted file mode 100644 index 1582edf7b..000000000 --- a/relay/healthcheck/quic.go +++ /dev/null @@ -1,31 +0,0 @@ -package healthcheck - -import ( - "context" - "crypto/tls" - "fmt" - "time" - - "github.com/quic-go/quic-go" - - tlsnb "github.com/netbirdio/netbird/shared/relay/tls" -) - -func dialQUIC(ctx context.Context, address string) error { - tlsConfig := &tls.Config{ - InsecureSkipVerify: false, // Keep certificate validation enabled - NextProtos: []string{tlsnb.NBalpn}, - } - - conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{ - MaxIdleTimeout: 30 * time.Second, - KeepAlivePeriod: 10 * time.Second, - EnableDatagrams: true, - }) - if err != nil { - return fmt.Errorf("failed to connect to QUIC server: %w", err) - } - - _ = conn.CloseWithError(0, "availability check complete") - return nil -} diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index 49694356c..badd31219 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -3,6 +3,7 @@ package healthcheck import ( "context" "fmt" + "strings" "github.com/coder/websocket" @@ -10,12 +11,19 @@ import ( ) func dialWS(ctx context.Context, address string) error { - url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath) + addressSplit := strings.Split(address, "/") + scheme := "ws" + if addressSplit[0] == "rels:" { + scheme = "wss" + } + url := fmt.Sprintf("%s://%s%s", scheme, addressSplit[2], relay.WebSocketURLPath) conn, resp, err := websocket.Dial(ctx, url, nil) if resp != nil { defer func() { - _ = resp.Body.Close() + if resp.Body != nil { + _ = resp.Body.Close() + } }() } diff --git a/relay/server/relay.go b/relay/server/relay.go index d86684937..aab575bf0 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -51,10 +51,11 @@ type Relay struct { metricsCancel context.CancelFunc validator Validator - store *store.Store - notifier *store.PeerNotifier - instanceURL string - preparedMsg *preparedMsg + store *store.Store + notifier *store.PeerNotifier + instanceURL string + exposedAddress string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -87,12 +88,13 @@ func NewRelay(config Config) (*Relay, error) { } r := &Relay{ - metrics: m, - metricsCancel: metricsCancel, - validator: config.AuthValidator, - instanceURL: config.instanceURL, - store: store.NewStore(), - notifier: store.NewPeerNotifier(), + metrics: m, + metricsCancel: metricsCancel, + validator: config.AuthValidator, + instanceURL: config.instanceURL, + exposedAddress: config.ExposedAddress, + store: store.NewStore(), + notifier: store.NewPeerNotifier(), } r.preparedMsg, err = newPreparedMsg(r.instanceURL) @@ -178,3 +180,8 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } + +// ExposedAddress returns the exposed address (domain:port) where clients connect +func (r *Relay) ExposedAddress() string { + return r.exposedAddress +} diff --git a/relay/server/server.go b/relay/server/server.go index 4c30e7fdc..2c9e658d6 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -28,8 +28,6 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { - listenAddr string - relay *Relay listeners []listener.Listener listenerMux sync.Mutex @@ -62,8 +60,6 @@ func NewServer(config Config) (*Server, error) { // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.listenAddr = cfg.Address - wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, @@ -139,6 +135,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol { return result } -func (r *Server) ListenAddress() string { - return r.listenAddr +func (r *Server) ExposedAddress() string { + return r.relay.ExposedAddress() } From e47d815dd27526a04e54c212161152ba5a2a8dee Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 1 Dec 2025 14:16:03 +0100 Subject: [PATCH 078/118] Fix IsAnotherProcessRunning (#4858) Compare the exact process name rather than searching for a substring of the full path --- client/ui/process/process.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/ui/process/process.go b/client/ui/process/process.go index d0ef54896..28276f416 100644 --- a/client/ui/process/process.go +++ b/client/ui/process/process.go @@ -28,7 +28,8 @@ func IsAnotherProcessRunning() (int32, bool, error) { continue } - if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) { + runningProcessName := strings.ToLower(filepath.Base(runningProcessPath)) + if runningProcessName == processName && isProcessOwnedByCurrentUser(p) { return p.Pid, true, nil } } From 387d43bcc18eb51baf619005f11bb449cb1514f1 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 1 Dec 2025 14:25:52 +0100 Subject: [PATCH 079/118] [client, management] Add OAuth select_account prompt support to PKCE flow (#4880) * Add OAuth select_account prompt support to PKCE flow Extends LoginFlag enum with select_account options to enable multi-account selection during authentication. This allows users to choose which account to use when multiple accounts have active sessions with the identity provider. The new flags are backward compatible - existing LoginFlag values (0=prompt login, 1=max_age=0) retain their original behavior. --- client/internal/auth/pkce_flow.go | 10 ++++--- client/internal/auth/pkce_flow_test.go | 35 ++++++++++++++---------- shared/management/client/common/types.go | 25 +++++++++-------- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 39da2a79f..c03376f5b 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/templates" + "github.com/netbirdio/netbird/shared/management/client/common" ) var _ OAuthFlow = &PKCEAuthorizationFlow{} @@ -104,11 +105,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), } if !p.providerConfig.DisablePromptLogin { - if p.providerConfig.LoginFlag.IsPromptLogin() { - params = append(params, oauth2.SetAuthURLParam("prompt", "login")) - } - if p.providerConfig.LoginFlag.IsMaxAge0Login() { + switch p.providerConfig.LoginFlag { + case common.LoginFlagPromptLogin: + params = append(params, oauth2.SetAuthURLParam("prompt", "login select_account")) + case common.LoginFlagMaxAge0: params = append(params, oauth2.SetAuthURLParam("max_age", "0")) + params = append(params, oauth2.SetAuthURLParam("prompt", "select_account")) } } if p.providerConfig.LoginHint != "" { diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index 380a360e5..b5843f104 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -15,30 +15,37 @@ import ( func TestPromptLogin(t *testing.T) { const ( - promptLogin = "prompt=login" - maxAge0 = "max_age=0" + promptSelectAccountLogin = "prompt=login+select_account" + promptSelectAccount = "prompt=select_account" + maxAge0 = "max_age=0" ) tt := []struct { name string loginFlag mgm.LoginFlag disablePromptLogin bool - expect string + expectContains []string }{ { - name: "Prompt login", - loginFlag: mgm.LoginFlagPrompt, - expect: promptLogin, + name: "Prompt login with select account", + loginFlag: mgm.LoginFlagPromptLogin, + expectContains: []string{promptSelectAccountLogin}, }, { - name: "Max age 0 login", - loginFlag: mgm.LoginFlagMaxAge0, - expect: maxAge0, + name: "Max age 0 with select account", + loginFlag: mgm.LoginFlagMaxAge0, + expectContains: []string{maxAge0, promptSelectAccount}, }, { name: "Disable prompt login", - loginFlag: mgm.LoginFlagPrompt, + loginFlag: mgm.LoginFlagPromptLogin, disablePromptLogin: true, + expectContains: []string{}, + }, + { + name: "None flag should not add parameters", + loginFlag: mgm.LoginFlagNone, + expectContains: []string{}, }, } @@ -53,6 +60,7 @@ func TestPromptLogin(t *testing.T) { RedirectURLs: []string{"http://127.0.0.1:33992/"}, UseIDToken: true, LoginFlag: tc.loginFlag, + DisablePromptLogin: tc.disablePromptLogin, } pkce, err := NewPKCEAuthorizationFlow(config) if err != nil { @@ -63,11 +71,8 @@ func TestPromptLogin(t *testing.T) { t.Fatalf("Failed to request auth info: %v", err) } - if !tc.disablePromptLogin { - require.Contains(t, authInfo.VerificationURIComplete, tc.expect) - } else { - require.Contains(t, authInfo.VerificationURIComplete, promptLogin) - require.NotContains(t, authInfo.VerificationURIComplete, maxAge0) + for _, expected := range tc.expectContains { + require.Contains(t, authInfo.VerificationURIComplete, expected) } }) } diff --git a/shared/management/client/common/types.go b/shared/management/client/common/types.go index 699617574..550bcde30 100644 --- a/shared/management/client/common/types.go +++ b/shared/management/client/common/types.go @@ -1,19 +1,20 @@ package common -// LoginFlag introduces additional login flags to the PKCE authorization request +// LoginFlag introduces additional login flags to the PKCE authorization request. +// +// # Config Values +// +// | Value | Flag | OAuth Parameters | +// |-------|----------------------|-----------------------------------------| +// | 0 | LoginFlagPromptLogin | prompt=select_account login | +// | 1 | LoginFlagMaxAge0 | max_age=0 & prompt=select_account | type LoginFlag uint8 const ( - // LoginFlagPrompt adds prompt=login to the authorization request - LoginFlagPrompt LoginFlag = iota - // LoginFlagMaxAge0 adds max_age=0 to the authorization request + // LoginFlagPromptLogin adds prompt=select_account login to the authorization request + LoginFlagPromptLogin LoginFlag = iota + // LoginFlagMaxAge0 adds max_age=0 and prompt=select_account to the authorization request LoginFlagMaxAge0 + // LoginFlagNone disables all login flags + LoginFlagNone ) - -func (l LoginFlag) IsPromptLogin() bool { - return l == LoginFlagPrompt -} - -func (l LoginFlag) IsMaxAge0Login() bool { - return l == LoginFlagMaxAge0 -} From 4b77359042c65d0d6cc4b7f19e76d4bf85de137a Mon Sep 17 00:00:00 2001 From: Fahri Shihab <22738442+fahrishih@users.noreply.github.com> Date: Mon, 1 Dec 2025 22:57:42 +0700 Subject: [PATCH 080/118] [management] Groups API with name query parameter (#4831) --- .../http/handlers/groups/groups_handler.go | 23 +++++ .../handlers/groups/groups_handler_test.go | 91 ++++++++++++++++++- shared/management/client/rest/groups.go | 25 +++++ shared/management/http/api/openapi.yml | 10 ++ shared/management/http/api/types.gen.go | 6 ++ 5 files changed, 154 insertions(+), 1 deletion(-) diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 208a2e828..56ccc9d0b 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -48,6 +48,29 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { } accountID, userID := userAuth.AccountId, userAuth.UserId + // Check if filtering by name + groupName := r.URL.Query().Get("name") + if groupName != "" { + // Get single group by name + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + // Return as array with single element to maintain API consistency + groupsResponse := []*api.Group{toGroupResponse(accountPeers, group)} + util.WriteJSONObject(r.Context(), w, groupsResponse) + return + } + + // Get all groups groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index b7dd3944a..458a15c11 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -60,12 +60,23 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return group, nil }, + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + groups := []*types.Group{ + {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + {ID: "id-existed", Name: "Existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, + } + + groups = append(groups, initGroups...) + + return groups, nil + }, GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } - return nil, fmt.Errorf("unknown group name") + return nil, status.Errorf(status.NotFound, "unknown group name") }, GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { return maps.Values(TestPeers), nil @@ -287,6 +298,84 @@ func TestWriteGroup(t *testing.T) { } } +func TestGetAllGroups(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + expectedCount int + }{ + { + name: "Get All Groups", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/groups", + expectedStatus: http.StatusOK, + expectedCount: 3, // id-jwt-group, id-existed, id-all + }, + { + name: "Get Group By Name - Existing", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/groups?name=All", + expectedStatus: http.StatusOK, + expectedCount: 1, + }, + { + name: "Get Group By Name - Not Found", + expectedBody: false, + requestType: http.MethodGet, + requestPath: "/api/groups?name=NonExistent", + expectedStatus: http.StatusNotFound, + }, + } + + p := initGroupTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) + + router := mux.NewRouter() + router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.expectedStatus) + return + } + + if !tc.expectedBody { + return + } + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var groups []api.Group + if err = json.Unmarshal(content, &groups); err != nil { + t.Fatalf("Response is not in correct json format; %v", err) + } + + assert.Equal(t, tc.expectedCount, len(groups)) + }) + } +} + func TestDeleteGroup(t *testing.T) { tt := []struct { name string diff --git a/shared/management/client/rest/groups.go b/shared/management/client/rest/groups.go index af068e077..7cd9535dd 100644 --- a/shared/management/client/rest/groups.go +++ b/shared/management/client/rest/groups.go @@ -4,10 +4,14 @@ import ( "bytes" "context" "encoding/json" + "errors" "github.com/netbirdio/netbird/shared/management/http/api" ) +// ErrGroupNotFound is returned when a group is not found +var ErrGroupNotFound = errors.New("group not found") + // GroupsAPI APIs for Groups, do not use directly type GroupsAPI struct { c *Client @@ -27,6 +31,27 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) { return ret, err } +// GetByName get group by name +// See more: https://docs.netbird.io/api/resources/groups#list-all-groups +func (a *GroupsAPI) GetByName(ctx context.Context, groupName string) (*api.Group, error) { + params := map[string]string{"name": groupName} + resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, params) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.Group](resp) + if err != nil { + return nil, err + } + if len(ret) == 0 { + return nil, ErrGroupNotFound + } + return &ret[0], nil +} + // Get get group info // See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) { diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 4a5454002..2d063a7b5 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3362,6 +3362,14 @@ paths: security: - BearerAuth: [ ] - TokenAuth: [ ] + parameters: + - in: query + name: name + required: false + schema: + type: string + description: Filter groups by name (exact match) + example: "devs" responses: '200': description: A JSON Array of Groups @@ -3375,6 +3383,8 @@ paths: "$ref": "#/components/responses/bad_request" '401': "$ref": "#/components/responses/requires_authentication" + '404': + "$ref": "#/components/responses/not_found" '403': "$ref": "#/components/responses/forbidden" '500': diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 9611d26d6..d3e425548 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1908,6 +1908,12 @@ type GetApiEventsNetworkTrafficParamsConnectionType string // GetApiEventsNetworkTrafficParamsDirection defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParamsDirection string +// GetApiGroupsParams defines parameters for GetApiGroups. +type GetApiGroupsParams struct { + // Name Filter groups by name (exact match) + Name *string `form:"name,omitempty" json:"name,omitempty"` +} + // GetApiPeersParams defines parameters for GetApiPeers. type GetApiPeersParams struct { // Name Filter peers by name From 52948ccd617eb8e95ac9860865129f9bf1be7af1 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 2 Dec 2025 14:17:59 +0300 Subject: [PATCH 081/118] [management] Add user created activity event (#4893) --- management/server/activity/codes.go | 2 ++ management/server/user.go | 24 +++++++++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 5c5989f84..2e3be1ef5 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -179,6 +179,7 @@ const ( PeerIPUpdated Activity = 88 UserApproved Activity = 89 UserRejected Activity = 90 + UserCreated Activity = 91 AccountDeleted Activity = 99999 ) @@ -288,6 +289,7 @@ var activityMap = map[Activity]Code{ PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, UserApproved: {"User approved", "user.approve"}, UserRejected: {"User rejected", "user.reject"}, + UserCreated: {"User created", "user.create"}, } // StringCode returns a string code of the activity diff --git a/management/server/user.go b/management/server/user.go index 6b8bcbcad..cefc4d1a5 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -596,9 +596,15 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, isNewUser bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { var eventsToStore []func() + if isNewUser { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.UserCreated, nil) + }) + } + if oldUser.IsBlocked() != newUser.IsBlocked() { if newUser.IsBlocked() { eventsToStore = append(eventsToStore, func() { @@ -661,7 +667,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) + oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) if err != nil { return false, nil, nil, nil, err } @@ -716,30 +722,30 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. -func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { +func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, bool, error) { existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id) if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if !addIfNotExists { - return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + return nil, false, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) } update.AccountID = accountID - return update, nil // use all fields from update if addIfNotExists is true + return update, true, nil // use all fields from update if addIfNotExists is true } - return nil, err + return nil, false, err } if existingUser.AccountID != accountID { - return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch") + return nil, false, status.Errorf(status.InvalidArgument, "user account ID mismatch") } - return existingUser, nil + return existingUser, false, nil } func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) { From 7193bd2da7bed767d178c3adaf6f864120f7ca51 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 2 Dec 2025 12:34:28 +0100 Subject: [PATCH 082/118] [management] Refactor network map controller (#4789) --- client/cmd/testutil_test.go | 13 +- client/internal/engine_test.go | 13 +- client/server/server_test.go | 13 +- go.mod | 2 +- go.sum | 4 +- .../network_map/controller/controller.go | 160 ++++++++++------- .../network_map/controller/repository.go | 10 ++ .../controllers/network_map/interface.go | 14 +- .../controllers/network_map/interface_mock.go | 89 ++++++---- .../modules}/peers/ephemeral/interface.go | 5 + .../peers/ephemeral/manager/ephemeral.go | 37 ++-- .../peers/ephemeral/manager/ephemeral_test.go | 115 ++++++++++--- management/internals/modules/peers/manager.go | 162 ++++++++++++++++++ .../modules}/peers/manager_mock.go | 53 ++++++ management/internals/server/boot.go | 30 ++-- management/internals/server/controllers.go | 38 ++-- management/internals/server/modules.go | 25 +-- management/internals/server/server.go | 20 +-- management/internals/shared/grpc/server.go | 114 +++++++----- .../internals/shared/grpc/server_test.go | 8 +- management/internals/shared/grpc/token_mgr.go | 18 +- .../internals/shared/grpc/token_mgr_test.go | 9 +- management/server/account.go | 13 +- management/server/account/manager.go | 2 - management/server/account_test.go | 10 +- management/server/cache/idp.go | 1 + management/server/cache/idp_test.go | 2 +- management/server/cache/store.go | 16 +- management/server/cache/store_test.go | 6 +- management/server/dns_test.go | 4 +- management/server/http/handler.go | 2 +- .../http/handlers/peers/peers_handler.go | 28 +-- .../http/handlers/peers/peers_handler_test.go | 34 +--- .../testing/testing_tools/channel/channel.go | 5 +- management/server/management_proto_test.go | 16 +- management/server/management_test.go | 12 +- management/server/mock_server/account_mock.go | 6 - management/server/nameserver_test.go | 4 +- management/server/peer.go | 36 ++-- management/server/peer_test.go | 11 +- management/server/peers/manager.go | 68 -------- management/server/route_test.go | 4 +- management/server/user.go | 37 ++-- management/server/user_test.go | 29 +++- shared/management/client/client_test.go | 13 +- 45 files changed, 819 insertions(+), 492 deletions(-) rename management/{server => internals/modules}/peers/ephemeral/interface.go (83%) rename management/{server => internals/modules}/peers/ephemeral/manager/ephemeral.go (85%) rename management/{server => internals/modules}/peers/ephemeral/manager/ephemeral_test.go (69%) create mode 100644 management/internals/modules/peers/manager.go rename management/{server => internals/modules}/peers/manager_mock.go (55%) delete mode 100644 management/server/peers/manager.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 5477653a2..b9ff35945 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -15,6 +15,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" clientProto "github.com/netbirdio/netbird/client/proto" @@ -24,8 +26,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9252ce13e..5ab21e3e1 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -30,11 +30,12 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" @@ -54,7 +55,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -1628,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) - networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index f8592bc7a..5f28a2664 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -17,11 +17,12 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -35,7 +36,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -316,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) - networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index a72d09dc3..91e587c32 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 011a5f199..98d395ad1 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 h1:ecs4GMANgObopiy29zMmz2dIdOTJMwezUbrFy+zfSwE= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63/go.mod h1:JIWpjbCgDvZIt45C9vYpikU2gRXeDWrN7SiyGYd3Qrc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 49bb9cef3..022ea774c 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -19,6 +19,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" @@ -42,6 +43,7 @@ type Controller struct { accountManagerMetrics *telemetry.AccountManagerMetrics peersUpdateManager network_map.PeersUpdateManager settingsManager settings.Manager + EphemeralPeersManager ephemeral.Manager accountUpdateLocks sync.Map sendAccountUpdateLocks sync.Map @@ -70,7 +72,7 @@ type bufferUpdate struct { var _ network_map.Controller = (*Controller)(nil) -func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller { +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller { nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) if err != nil { log.Fatal(fmt.Errorf("error creating metrics: %w", err)) @@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App dnsDomain: dnsDomain, config: config, - proxyController: proxyController, + proxyController: proxyController, + EphemeralPeersManager: ephemeralPeersManager, holder: types.NewHolder(), expNewNetworkMap: newNetworkMapBuilder, @@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App } } +func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) { + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err) + } + + c.EphemeralPeersManager.OnPeerConnected(ctx, peer) + + return c.peersUpdateManager.CreateChannel(ctx, peerID), nil +} + +func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) { + c.peersUpdateManager.CloseChannel(ctx, peerID) + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err) + return + } + c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer) +} + +func (c *Controller) CountStreams() int { + return c.peersUpdateManager.CountStreams() +} + func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) var ( @@ -366,38 +394,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str return nil } -func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error { - network, err := c.repo.GetAccountNetwork(ctx, accountId) - if err != nil { - return err - } - - peers, err := c.repo.GetAccountPeers(ctx, accountId) - if err != nil { - return err - } - - dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) - c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{ - Update: &proto.SyncResponse{ - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - NetworkMap: &proto.NetworkMap{ - Serial: network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - DNSConfig: &proto.DNSConfig{ - ForwarderPort: dnsFwdPort, - }, - }, - }, - }) - c.peersUpdateManager.CloseChannel(ctx, peerId) - return nil -} - func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if isRequiresApproval { network, err := c.repo.GetAccountNetwork(ctx, accountID) @@ -698,35 +694,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t return false, nil } -func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) { - c.UpdatePeerInNetworkMapCache(accountId, peer) - _ = c.bufferSendUpdateAccountPeers(context.Background(), accountId) +func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { + peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("failed to get peers by ids: %w", err) + } + + for _, peer := range peers { + c.UpdatePeerInNetworkMapCache(accountID, peer) + } + + err = c.bufferSendUpdateAccountPeers(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) + } + + return nil } -func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error { - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } +func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + for _, peerID := range peerIDs { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } - err = c.onPeerAddedUpdNetworkMapCache(account, peerID) - if err != nil { - return err + err = c.onPeerAddedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } } } return c.bufferSendUpdateAccountPeers(ctx, accountID) } -func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error { - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) - if err != nil { - return err +func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return err + } + + peers, err := c.repo.GetAccountPeers(ctx, accountID) + if err != nil { + return err + } + + dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) + for _, peerID := range peerIDs { + c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, + }, + }, + }) + c.peersUpdateManager.CloseChannel(ctx, peerID) + + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) + continue + } + err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err) + continue + } } } @@ -778,10 +822,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N return networkMap, nil } -func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) { +func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { c.peersUpdateManager.CloseChannels(ctx, peerIDs) } - -func (c *Controller) IsConnected(peerID string) bool { - return c.peersUpdateManager.HasChannel(peerID) -} diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go index 44144263b..3ed51a5c3 100644 --- a/management/internals/controllers/network_map/controller/repository.go +++ b/management/internals/controllers/network_map/controller/repository.go @@ -12,6 +12,8 @@ type Repository interface { GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) + GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) } type repository struct { @@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]* func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { return r.store.GetAccountByPeerID(ctx, peerID) } + +func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) { + return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs) +} + +func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) { + return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index 6f893ce79..b1de7d017 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -28,12 +28,12 @@ type Controller interface { GetDNSDomain(settings *types.Settings) string StartWarmup(context.Context) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + CountStreams() int - DeletePeer(ctx context.Context, accountId string, peerId string) error - - OnPeerUpdated(accountId string, peer *nbpeer.Peer) - OnPeerAdded(ctx context.Context, accountID string, peerID string) error - OnPeerDeleted(ctx context.Context, accountID string, peerID string) error - DisconnectPeers(ctx context.Context, peerIDs []string) - IsConnected(peerID string) bool + OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error + OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error + OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error + DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) + OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error) + OnPeerDisconnected(ctx context.Context, accountID string, peerID string) } diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index aaa093e47..5a98eefa8 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) } -// DeletePeer mocks base method. -func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error { +// CountStreams mocks base method. +func (m *MockController) CountStreams() int { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "CountStreams") + ret0, _ := ret[0].(int) return ret0 } -// DeletePeer indicates an expected call of DeletePeer. -func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call { +// CountStreams indicates an expected call of CountStreams. +func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams)) } // DisconnectPeers mocks base method. -func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) { +func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { m.ctrl.T.Helper() - m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs) + m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs) } // DisconnectPeers indicates an expected call of DisconnectPeers. -func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call { +func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs) } // GetDNSDomain mocks base method. @@ -130,58 +130,73 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) } -// IsConnected mocks base method. -func (m *MockController) IsConnected(peerID string) bool { +// OnPeerConnected mocks base method. +func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsConnected", peerID) - ret0, _ := ret[0].(bool) - return ret0 + ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID) + ret0, _ := ret[0].(chan *UpdateMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// IsConnected indicates an expected call of IsConnected. -func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call { +// OnPeerConnected indicates an expected call of OnPeerConnected. +func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID) } -// OnPeerAdded mocks base method. -func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error { +// OnPeerDisconnected mocks base method. +func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID) + m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID) +} + +// OnPeerDisconnected indicates an expected call of OnPeerDisconnected. +func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID) +} + +// OnPeersAdded mocks base method. +func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs) ret0, _ := ret[0].(error) return ret0 } -// OnPeerAdded indicates an expected call of OnPeerAdded. -func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call { +// OnPeersAdded indicates an expected call of OnPeersAdded. +func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs) } -// OnPeerDeleted mocks base method. -func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error { +// OnPeersDeleted mocks base method. +func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID) + ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs) ret0, _ := ret[0].(error) return ret0 } -// OnPeerDeleted indicates an expected call of OnPeerDeleted. -func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call { +// OnPeersDeleted indicates an expected call of OnPeersDeleted. +func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs) } -// OnPeerUpdated mocks base method. -func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) { +// OnPeersUpdated mocks base method. +func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPeerUpdated", accountId, peer) + ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs) + ret0, _ := ret[0].(error) + return ret0 } -// OnPeerUpdated indicates an expected call of OnPeerUpdated. -func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call { +// OnPeersUpdated indicates an expected call of OnPeersUpdated. +func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs) } // StartWarmup mocks base method. diff --git a/management/server/peers/ephemeral/interface.go b/management/internals/modules/peers/ephemeral/interface.go similarity index 83% rename from management/server/peers/ephemeral/interface.go rename to management/internals/modules/peers/ephemeral/interface.go index a1605b3b9..8fe25435c 100644 --- a/management/server/peers/ephemeral/interface.go +++ b/management/internals/modules/peers/ephemeral/interface.go @@ -2,10 +2,15 @@ package ephemeral import ( "context" + "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) +const ( + EphemeralLifeTime = 10 * time.Minute +) + type Manager interface { LoadInitialPeers(ctx context.Context) Stop() diff --git a/management/server/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go similarity index 85% rename from management/server/peers/ephemeral/manager/ephemeral.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral.go index 062ba69d2..15119045b 100644 --- a/management/server/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -7,14 +7,15 @@ import ( log "github.com/sirupsen/logrus" - nbAccount "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" ) const ( - ephemeralLifeTime = 10 * time.Minute // cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure. cleanupWindow = 1 * time.Minute ) @@ -33,11 +34,11 @@ type ephemeralPeer struct { // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it // in worst case we will get invalid error message in this manager. -// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted +// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store store.Store - accountManager nbAccount.Manager + store store.Store + peersManager peers.Manager headPeer *ephemeralPeer tailPeer *ephemeralPeer @@ -49,12 +50,12 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager { +func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager { return &EphemeralManager{ - store: store, - accountManager: accountManager, + store: store, + peersManager: peersManager, - lifeTime: ephemeralLifeTime, + lifeTime: ephemeral.EphemeralLifeTime, cleanupWindow: cleanupWindow, } } @@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee } // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer -// is inactive it will be deleted after the ephemeralLifeTime period. +// is inactive it will be deleted after the EphemeralLifeTime period. func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return @@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() - bufferAccountCall := make(map[string]struct{}) - + peerIDsPerAccount := make(map[string][]string) for id, p := range deletePeers { - log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) + peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id) + } + + for accountID, peerIDs := range peerIDsPerAccount { + log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID) + err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) - } else { - bufferAccountCall[p.accountID] = struct{}{} } } - for accountID := range bufferAccountCall { - e.accountManager.BufferUpdateAccountPeers(ctx, accountID) - } } func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { diff --git a/management/server/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go similarity index 69% rename from management/server/peers/ephemeral/manager/ephemeral_test.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral_test.go index fc7525c29..9d3ed246a 100644 --- a/management/server/peers/ephemeral/manager/ephemeral_test.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go @@ -7,10 +7,13 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) if len(store.account.Peers) != numberOfPeers { @@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers (except the connected one) + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + 1 @@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for the one disconnected peer + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { mgr.OnPeerConnected(context.Background(), v) @@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + numberOfEphemeralPeers - 1 @@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { testLifeTime = 1 * time.Second testCleanupWindow = 100 * time.Millisecond ) + + t.Cleanup(func() { + timeNow = time.Now + }) + startTime := time.Now() + timeNow = func() time.Time { + return startTime + } + mockStore := &MockStore{} + account := newAccountWithId(context.Background(), "account", "", "", false) + mockStore.account = account + + wg := &sync.WaitGroup{} + wg.Add(ephemeralPeers) mockAM := &MockAccountManager{ store: mockStore, + wg: wg, } - mockAM.wg = &sync.WaitGroup{} - mockAM.wg.Add(ephemeralPeers) - mgr := NewEphemeralManager(mockStore, mockAM) + + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) + + // Set up expectation that DeletePeers will be called once with all peer IDs + peersManager.EXPECT(). + DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + // Simulate the actual deletion behavior + for _, peerID := range peerIDs { + err := mockAM.DeletePeer(ctx, accountID, peerID, userID) + if err != nil { + return err + } + } + mockAM.BufferUpdateAccountPeers(ctx, accountID) + return nil + }). + Times(1) + + mgr := NewEphemeralManager(mockStore, peersManager) mgr.lifeTime = testLifeTime mgr.cleanupWindow = testCleanupWindow - account := newAccountWithId(context.Background(), "account", "", "", false) - mockStore.account = account + // Add peers and disconnect them at slightly different times (within cleanup window) for i := range ephemeralPeers { p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true} mockStore.account.Peers[p.ID] = p - time.Sleep(testCleanupWindow / ephemeralPeers) mgr.OnPeerDisconnected(context.Background(), p) + startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2)) } - mockAM.wg.Wait() + + // Advance time past the lifetime to trigger cleanup + startTime = startTime.Add(testLifeTime + testCleanupWindow) + + // Wait for all deletions to complete + wg.Wait() + assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime") assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once") assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers") diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go new file mode 100644 index 000000000..e82f19e63 --- /dev/null +++ b/management/internals/modules/peers/manager.go @@ -0,0 +1,162 @@ +package peers + +//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod + +import ( + "context" + "fmt" + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +type Manager interface { + GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) + GetPeerAccountID(ctx context.Context, peerID string) (string, error) + GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) + DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error + SetNetworkMapController(networkMapController network_map.Controller) + SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) + SetAccountManager(accountManager account.Manager) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + integratedPeerValidator integrated_validator.IntegratedValidator + accountManager account.Manager + + networkMapController network_map.Controller +} + +func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + } +} + +func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) { + m.networkMapController = networkMapController +} + +func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.integratedPeerValidator = integratedPeerValidator +} + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + m.accountManager = accountManager +} + +func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} + +func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) + } + + return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { + return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) +} + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} + +func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + dnsDomain := m.networkMapController.GetDNSDomain(settings) + + for _, peerID := range peerIDs { + var eventsToStore []func() + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + + if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) { + return nil + } + + if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil { + return fmt.Errorf("failed to remove peer %s from groups", peerID) + } + + if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil { + return err + } + + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + + if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + + return nil + }) + if err != nil { + return err + } + for _, event := range eventsToStore { + event() + } + } + + return nil +} diff --git a/management/server/peers/manager_mock.go b/management/internals/modules/peers/manager_mock.go similarity index 55% rename from management/server/peers/manager_mock.go rename to management/internals/modules/peers/manager_mock.go index 994f8346b..2e3651e88 100644 --- a/management/server/peers/manager_mock.go +++ b/management/internals/modules/peers/manager_mock.go @@ -9,6 +9,9 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map" + account "github.com/netbirdio/netbird/management/server/account" + integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" peer "github.com/netbirdio/netbird/management/server/peer" ) @@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } +// DeletePeers mocks base method. +func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeers indicates an expected call of DeletePeers. +func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected) +} + // GetAllPeers mocks base method. func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { m.ctrl.T.Helper() @@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) } + +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + +// SetIntegratedPeerValidator mocks base method. +func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator) +} + +// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator. +func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator) +} + +// SetNetworkMapController mocks base method. +func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetNetworkMapController", networkMapController) +} + +// SetNetworkMapController indicates an expected call of SetNetworkMapController. +func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index eadd16c2d..37788e80e 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { - store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false) + store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) if err != nil { log.Fatalf("failed to create store: %v", err) } @@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store { log.Fatalf("failed to initialize integration metrics: %v", err) } - eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics) + eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics) if err != nil { log.Fatalf("failed to initialize event store: %v", err) } - if s.config.DataStoreEncryptionKey != key { - log.WithContext(context.Background()).Infof("update config with activity store key") - s.config.DataStoreEncryptionKey = key - err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config) + if s.Config.DataStoreEncryptionKey != key { + log.WithContext(context.Background()).Infof("update Config with activity store key") + s.Config.DataStoreEncryptionKey = key + err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config) if err != nil { - log.Fatalf("failed to update config with activity store: %v", err) + log.Fatalf("failed to update Config with activity store: %v", err) } } @@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { - trustedPeers := s.config.ReverseProxy.TrustedPeers + trustedPeers := s.Config.ReverseProxy.TrustedPeers defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) { log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.") trustedPeers = defaultTrustedPeers } - trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies - trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount + trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies + trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 { log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + "This is not recommended way to extract X-Forwarded-For. Consider using one of these options.") @@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server { grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor), } - if s.config.HttpConfig.LetsEncryptDomain != "" { - certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { log.Fatalf("failed to create certificate manager: %v", err) } transportCredentials := credentials.NewTLS(certManager.TLSConfig()) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.Fatalf("cannot load TLS credentials: %v", err) } @@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 38ec6fde6..3442c7646 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -9,17 +9,17 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { - return Create(s, func() *update_channel.PeersUpdateManager { + return Create(s, func() network_map.PeersUpdateManager { return update_channel.NewPeersUpdateManager(s.Metrics()) }) } @@ -44,33 +44,37 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller { }) } -func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager { - return Create(s, func() *grpc.TimeBasedAuthSecretsManager { - return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) +func (s *BaseServer) SecretsManager() grpc.SecretsManager { + return Create(s, func() grpc.SecretsManager { + secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager()) + if err != nil { + log.Fatalf("failed to create secrets manager: %v", err) + } + return secretsManager }) } func (s *BaseServer) AuthManager() auth.Manager { return Create(s, func() auth.Manager { return auth.NewManager(s.Store(), - s.config.HttpConfig.AuthIssuer, - s.config.HttpConfig.AuthAudience, - s.config.HttpConfig.AuthKeysLocation, - s.config.HttpConfig.AuthUserIDClaim, - s.config.GetAuthAudiences(), - s.config.HttpConfig.IdpSignKeyRefreshEnabled) + s.Config.HttpConfig.AuthIssuer, + s.Config.HttpConfig.AuthAudience, + s.Config.HttpConfig.AuthKeysLocation, + s.Config.HttpConfig.AuthUserIDClaim, + s.Config.GetAuthAudiences(), + s.Config.HttpConfig.IdpSignKeyRefreshEnabled) }) } func (s *BaseServer) EphemeralManager() ephemeral.Manager { return Create(s, func() ephemeral.Manager { - return manager.NewEphemeralManager(s.Store(), s.AccountManager()) + return manager.NewEphemeralManager(s.Store(), s.PeersManager()) }) } func (s *BaseServer) NetworkMapController() network_map.Controller { - return Create(s, func() *nmapcontroller.Controller { - return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config) + return Create(s, func() network_map.Controller { + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config) }) } @@ -79,3 +83,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { return server.NewAccountRequestBuffer(context.Background(), s.Store()) }) } + +func (s *BaseServer) DNSDomain() string { + return s.dnsDomain +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 18a8427be..91ce50a79 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/geolocation" @@ -14,7 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/users" @@ -22,12 +23,12 @@ import ( func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { return Create(s, func() geolocation.Geolocation { - geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate) + geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate) if err != nil { log.Fatalf("could not initialize geolocation service: %v", err) } - log.Infof("geolocation service has been initialized from %s", s.config.Datadir) + log.Infof("geolocation service has been initialized from %s", s.Config.Datadir) return geo }) @@ -60,20 +61,22 @@ func (s *BaseServer) SettingsManager() settings.Manager { func (s *BaseServer) PeersManager() peers.Manager { return Create(s, func() peers.Manager { - return peers.NewManager(s.Store(), s.PermissionsManager()) + manager := peers.NewManager(s.Store(), s.PermissionsManager()) + s.AfterInit(func(s *BaseServer) { + manager.SetNetworkMapController(s.NetworkMapController()) + manager.SetIntegratedPeerValidator(s.IntegratedValidator()) + manager.SetAccountManager(s.AccountManager()) + }) + return manager }) } func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } - - s.AfterInit(func(s *BaseServer) { - accountManager.SetEphemeralManager(s.EphemeralManager()) - }) return accountManager }) } @@ -82,8 +85,8 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error - if s.config.IdpManagerConfig != nil { - idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics()) + if s.Config.IdpManagerConfig != nil { + idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { log.Fatalf("failed to create IDP manager: %v", err) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index ab1c2ebe7..a1b144dac 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -41,10 +41,10 @@ type Server interface { } // Server holds the HTTP BaseServer instance. -// Add any additional fields you need, such as database connections, config, etc. +// Add any additional fields you need, such as database connections, Config, etc. type BaseServer struct { - // config holds the server configuration - config *nbconfig.Config + // Config holds the server configuration + Config *nbconfig.Config // container of dependencies, each dependency is identified by a unique string. container map[string]any // AfterInit is a function that will be called after the server is initialized @@ -70,7 +70,7 @@ type BaseServer struct { // NewServer initializes and configures a new Server instance func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer { return &BaseServer{ - config: config, + Config: config, container: make(map[string]any), dnsDomain: dnsDomain, mgmtSingleAccModeDomain: mgmtSingleAccModeDomain, @@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error { var tlsConfig *tls.Config tlsEnabled := false - if s.config.HttpConfig.LetsEncryptDomain != "" { - s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err) } tlsEnabled = true - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err) return err @@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error { if !s.disableMetrics { idpManager := "disabled" - if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" { - idpManager = s.config.IdpManagerConfig.ManagerType + if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" { + idpManager = s.Config.IdpManagerConfig.ManagerType } metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager) go metricsWorker.Run(srvCtx) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 0aadadf84..62dc215d8 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -55,15 +54,12 @@ const ( type Server struct { accountManager account.Manager settingsManager settings.Manager - wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager network_map.PeersUpdateManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager ephemeral.Manager - peerLocks sync.Map - authManager auth.Manager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + peerLocks sync.Map + authManager auth.Manager logBlockedPeers bool blockPeersWithSameConfig bool @@ -82,23 +78,16 @@ func NewServer( config *nbconfig.Config, accountManager account.Manager, settingsManager settings.Manager, - peersUpdateManager network_map.PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, ) (*Server, error) { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, err - } - if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams - err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(peersUpdateManager.CountStreams()) + err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { + return int64(networkMapController.CountStreams()) }) if err != nil { return nil, err @@ -120,16 +109,12 @@ func NewServer( } return &Server{ - wgKey: key, - // peerKey -> event channel - peersUpdateManager: peersUpdateManager, accountManager: accountManager, settingsManager: settingsManager, config: config, secretsManager: secretsManager, authManager: authManager, appMetrics: appMetrics, - ephemeralManager: ephemeralManager, logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, @@ -163,8 +148,14 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser nanos := int32(now.Nanosecond()) expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos} + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err) + return nil, errors.New("failed to get wireguard key") + } + return &proto.ServerKeyResponse{ - Key: s.wgKey.PublicKey().String(), + Key: key.PublicKey().String(), ExpiresAt: expiresAt, }, nil } @@ -269,9 +260,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return err } - updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - - s.ephemeralManager.OnPeerConnected(ctx, peer) + updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) + s.cancelPeerRoutines(ctx, accountID, peer) + return err + } s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) @@ -323,13 +318,19 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) + key, err := s.secretsManager.GetWGKey() + if err != nil { + s.cancelPeerRoutines(ctx, accountID, peer) + return status.Errorf(codes.Internal, "failed processing update message") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) if err != nil { @@ -348,9 +349,8 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } - s.peersUpdateManager.CloseChannel(ctx, peer.ID) + s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - s.ephemeralManager.OnPeerDisconnected(ctx, peer) log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) } @@ -504,7 +504,12 @@ func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, parsed) if err != nil { return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") } @@ -601,12 +606,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) - // if the login request contains setup key then it is a registration request - if loginReq.GetSetupKey() != "" { - s.ephemeralManager.OnPeerDisconnected(ctx, peer) - log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart)) - } - loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) if err != nil { log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) @@ -615,14 +614,20 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) + return nil, status.Errorf(codes.Internal, "failed logging in peer") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) return nil, status.Errorf(codes.Internal, "failed logging in peer") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } @@ -715,14 +720,19 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return status.Errorf(codes.Internal, "failed getting server key") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp) if err != nil { return status.Errorf(codes.Internal, "error handling request") } sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) @@ -752,7 +762,12 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -782,13 +797,13 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr }, } - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } @@ -810,7 +825,12 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -838,13 +858,13 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go index 9867b38e3..d3a12e986 100644 --- a/management/internals/shared/grpc/server_test.go +++ b/management/internals/shared/grpc/server_test.go @@ -73,15 +73,17 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { mgmtServer := &Server{ - wgKey: testingServerKey, + secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey}, config: &config.Config{ DeviceAuthorizationFlow: testCase.inputFlow, }, } message := &mgmtProto.DeviceAuthorizationFlowRequest{} + key, err := mgmtServer.secretsManager.GetWGKey() + require.NoError(t, err, "should be able to get server key") - encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) + encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message) require.NoError(t, err, "should be able to encrypt message") resp, err := mgmtServer.GetDeviceAuthorizationFlow( @@ -95,7 +97,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { if testCase.expectedComparisonFunc != nil { flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} - err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) + err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp) require.NoError(t, err, "should be able to decrypt") testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index e9770db41..0f893ae3a 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -10,6 +10,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" "github.com/netbirdio/netbird/management/internals/controllers/network_map" @@ -29,6 +30,7 @@ type SecretsManager interface { GenerateRelayToken() (*Token, error) SetupRefresh(ctx context.Context, accountID, peerKey string) CancelRefresh(peerKey string) + GetWGKey() (wgtypes.Key, error) } // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server @@ -43,11 +45,17 @@ type TimeBasedAuthSecretsManager struct { groupsManager groups.Manager turnCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{} + wgKey wgtypes.Key } type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -56,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager relayCancelMap: make(map[string]chan struct{}), settingsManager: settingsManager, groupsManager: groupsManager, + wgKey: key, } if turnCfg != nil { @@ -81,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager } } - return mgr + return mgr, nil +} + +// GetWGKey returns WireGuard private key used to generate peer keys +func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) { + return m.wgKey, nil } // GenerateTurnToken generates new time-based secret credentials for TURN diff --git a/management/internals/shared/grpc/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go index 06d28d05b..98eb66fb5 100644 --- a/management/internals/shared/grpc/token_mgr_test.go +++ b/management/internals/shared/grpc/token_mgr_test.go @@ -46,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) turnCredentials, err := tested.GenerateTurnToken() require.NoError(t, err) @@ -98,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -201,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) tested.SetupRefresh(context.Background(), "someAccountID", peer) if _, ok := tested.turnCancelMap[peer]; !ok { diff --git a/management/server/account.go b/management/server/account.go index 716d5ab5d..dac040db0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -37,7 +37,6 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -77,7 +76,6 @@ type DefaultAccountManager struct { ctx context.Context eventStore activity.Store geo geolocation.Geolocation - ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer @@ -238,7 +236,7 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) if err != nil { return nil, fmt.Errorf("getting cache store: %s", err) } @@ -263,10 +261,6 @@ func BuildManager( return am, nil } -func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { - am.ephemeralManager = em -} - func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { return am.externalCacheManager } @@ -2076,7 +2070,10 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us if err != nil { return err } - am.networkMapController.OnPeerUpdated(peer.AccountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 9b3902d87..b5921ec7a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -13,7 +13,6 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -124,5 +123,4 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) - SetEphemeralManager(em ephemeral.Manager) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 340e8db18..8569f1b2f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -25,6 +25,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" @@ -2959,8 +2961,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) - manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err } @@ -3371,7 +3373,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { t.Run("memory cache", func(t *testing.T) { t.Run("should always return true", func(t *testing.T) { - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) cold, err := manager.isCacheCold(context.Background(), cacheStore) @@ -3386,7 +3388,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) t.Run("should return true when no account exists", func(t *testing.T) { diff --git a/management/server/cache/idp.go b/management/server/cache/idp.go index 1b31ff82a..19dfc0f38 100644 --- a/management/server/cache/idp.go +++ b/management/server/cache/idp.go @@ -18,6 +18,7 @@ const ( DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days DefaultIDPCacheCleanupInterval = 30 * time.Minute + DefaultIDPCacheOpenConn = 100 ) // UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects. diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go index 3fcfbb11a..0e8061e94 100644 --- a/management/server/cache/idp_test.go +++ b/management/server/cache/idp_test.go @@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) } - cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) + cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn) if err != nil { t.Fatalf("couldn't create cache store: %s", err) } diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 1c141a180..54b0242de 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -3,6 +3,7 @@ package cache import ( "context" "fmt" + "math" "os" "time" @@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. -func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) { +func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { redisAddr := os.Getenv(RedisStoreEnvVar) if redisAddr != "" { - return getRedisStore(ctx, redisAddr) + return getRedisStore(ctx, redisAddr, maxConn) } goc := gocache.New(maxTimeout, cleanupInterval) return gocache_store.NewGoCache(goc), nil } -func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) { +func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { return nil, fmt.Errorf("parsing redis cache url: %s", err) } - options.MaxIdleConns = 6 - options.MinIdleConns = 3 - options.MaxActiveConns = 100 + options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns + options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns + options.MaxActiveConns = maxConn + options.ConnMaxIdleTime = 30 * time.Minute + options.ConnMaxLifetime = 0 + options.PoolTimeout = 10 * time.Second redisClient := redis.NewClient(options) subCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() diff --git a/management/server/cache/store_test.go b/management/server/cache/store_test.go index f49dd6bbd..1b64fd70d 100644 --- a/management/server/cache/store_test.go +++ b/management/server/cache/store_test.go @@ -15,7 +15,7 @@ import ( ) func TestMemoryStore(t *testing.T) { - memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create memory store: %s", err) } @@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) { func TestRedisStoreConnectionFailure(t *testing.T) { t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379") - _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond) + _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100) if err == nil { t.Fatal("getting redis cache store should return error") } @@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) { } t.Setenv(cache.RedisStoreEnvVar, redisURL) - redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create redis store: %s", err) } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 99b09566a..b5e3f2b99 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -12,6 +12,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" @@ -223,7 +225,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c1a8c5885..7cf0b5765 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" + nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" @@ -39,7 +40,6 @@ import ( nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - nbpeers "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/telemetry" ) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index c4c5ae165..f531f0cdb 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -45,19 +45,6 @@ func NewHandler(accountManager account.Manager, networkMapController network_map } } -func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { - peerToReturn := peer.Copy() - if peer.Status.Connected { - // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected - // This may happen after server restart when not all peers are yet connected - if !h.networkMapController.IsConnected(peer.ID) { - peerToReturn.Status.Connected = false - } - } - - return peerToReturn, nil -} - func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { @@ -65,11 +52,6 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, return } - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(ctx, err, w) - return - } settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) if err != nil { util.WriteError(ctx, err, w) @@ -91,7 +73,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, _, valid := validPeers[peer.ID] reason := invalidPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -237,13 +219,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) + respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0)) } validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index ddf2e2a70..55e779ff0 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -109,14 +109,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { GetDNSDomain(gomock.Any()). Return("domain"). AnyTimes() - networkMapController.EXPECT(). - IsConnected(noUpdateChannelTestPeerID). - Return(false). - AnyTimes() - networkMapController.EXPECT(). - IsConnected(gomock.Any()). - Return(true). - AnyTimes() return &Handler{ accountManager: &mock_server.MockAccountManager{ @@ -269,14 +261,6 @@ func TestGetPeers(t *testing.T) { expectedArray: false, expectedPeer: peer, }, - { - name: "GetPeer with no update channel", - requestType: http.MethodGet, - requestPath: "/api/peers/" + peer1.ID, - expectedStatus: http.StatusOK, - expectedArray: false, - expectedPeer: expectedPeer1, - }, { name: "PutPeer", requestType: http.MethodPut, @@ -336,8 +320,6 @@ func TestGetPeers(t *testing.T) { for _, peer := range respBody { if peer.Id == testPeerID { got = peer - } else { - assert.Equal(t, peer.Connected, false) } } @@ -351,14 +333,14 @@ func TestGetPeers(t *testing.T) { t.Log(got) - assert.Equal(t, got.Name, tc.expectedPeer.Name) - assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion) - assert.Equal(t, got.Ip, tc.expectedPeer.IP.String()) - assert.Equal(t, got.Os, "OS core") - assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled) - assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled) - assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected) - assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber) + assert.Equal(t, tc.expectedPeer.Name, got.Name) + assert.Equal(t, tc.expectedPeer.Meta.WtVersion, got.Version) + assert.Equal(t, tc.expectedPeer.IP.String(), got.Ip) + assert.Equal(t, "OS core", got.Os) + assert.Equal(t, tc.expectedPeer.LoginExpirationEnabled, got.LoginExpirationEnabled) + assert.Equal(t, tc.expectedPeer.SSHEnabled, got.SshEnabled) + assert.Equal(t, tc.expectedPeer.Status.Connected, got.Connected) + assert.Equal(t, tc.expectedPeer.Meta.SystemSerialNumber, got.SerialNumber) }) } } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index e292a7d6c..e8513feb5 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -15,6 +15,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server" @@ -28,7 +30,6 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -72,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee ctx := context.Background() requestBuffer := server.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 42311d944..42f192c0a 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -24,13 +24,14 @@ import ( "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -363,7 +364,9 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) @@ -372,10 +375,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } - ephemeralMgr := manager.NewEphemeralManager(store, accountManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 2350b225b..648201d4e 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -22,13 +22,14 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -205,7 +206,7 @@ func startServer( ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) - networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) accountManager, err := server.BuildManager( context.Background(), @@ -228,15 +229,16 @@ func startServer( } groupsManager := groups.NewManager(str, permissionsManager, accountManager) - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatalf("failed creating secrets manager: %v", err) + } mgmtServer, err := nbgrpc.NewServer( config, accountManager, settingsMockManager, - updateManager, secretsManager, nil, - &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, networkMapController, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 0178e51f5..928098dbe 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,7 +15,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -976,11 +975,6 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth a return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } -// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface -func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { - // Mock implementation - does nothing -} - func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index a4574c978..e3dd8b0b8 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -13,6 +13,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -792,7 +794,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/peer.go b/management/server/peer.go index cd9fbe4c8..f2de05f15 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -136,7 +136,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil @@ -309,7 +312,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, fmt.Errorf("notify network map controller of peer update: %w", err) + } return peer, nil } @@ -365,13 +371,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) - } - - if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err) } return nil @@ -583,11 +584,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed adding peer to All group: %w", err) } - if temporary { - // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually - am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) - } - if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -645,7 +641,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil { + if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } @@ -729,7 +725,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) + } } return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) @@ -857,7 +856,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err) + } } p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 2d09f5200..752563299 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -28,6 +28,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" @@ -1058,6 +1060,7 @@ func testUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel + assert.Nil(t, update.Update.NetbirdConfig) assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } @@ -1290,7 +1293,7 @@ func Test_RegisterPeerByUser(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1375,7 +1378,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1528,7 +1531,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1608,7 +1611,7 @@ func Test_LoginPeer(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go deleted file mode 100644 index cb135f4ac..000000000 --- a/management/server/peers/manager.go +++ /dev/null @@ -1,68 +0,0 @@ -package peers - -//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod - -import ( - "context" - "fmt" - - "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" -) - -type Manager interface { - GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) - GetPeerAccountID(ctx context.Context, peerID string) (string, error) - GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) - GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) -} - -type managerImpl struct { - store store.Store - permissionsManager permissions.Manager -} - -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { - return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - } -} - -func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) -} - -func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) - } - - return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") -} - -func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { - return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) -} - -func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { - return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) -} diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c8b636bc..a413d545b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -16,6 +16,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -1291,7 +1293,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { diff --git a/management/server/user.go b/management/server/user.go index cefc4d1a5..ca02f91e6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -263,15 +263,11 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return err } - updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) - } - return nil } @@ -998,14 +994,17 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) + } - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) } if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) - am.networkMapController.DisconnectPeers(ctx, peerIDs) + am.networkMapController.DisconnectPeers(ctx, accountID, peerIDs) } return nil } @@ -1051,7 +1050,6 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } var allErrors error - var updateAccountPeers bool for _, targetUserID := range targetUserIDs { if initiatorUserID == targetUserID { @@ -1082,19 +1080,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { allErrors = errors.Join(allErrors, err) continue } - - if userHadPeers { - updateAccountPeers = true - } - } - - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) } return allErrors @@ -1152,15 +1142,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return false, err } + var peerIDs []string for _, peer := range userPeers { - err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) - } - - if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err) - } + peerIDs = append(peerIDs, peer.ID) + } + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil { + log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err) } for _, addPeerRemovedEvent := range addPeerRemovedEvents { diff --git a/management/server/user_test.go b/management/server/user_test.go index 5ce15621e..0d778cfa2 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -8,8 +8,10 @@ import ( "time" "github.com/google/go-cmp/cmp" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -547,7 +549,7 @@ func TestUser_InviteNewUser(t *testing.T) { permissionsManager: permissionsManager, } - cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) require.NoError(t, err) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cs) @@ -739,11 +741,18 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - permissionsManager: permissionsManager, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -848,12 +857,20 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, integratedPeerValidator: MockIntegratedValidator{}, permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -1056,7 +1073,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { permissionsManager: permissionsManager, } - cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) assert.NoError(t, err) am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) @@ -1412,7 +1429,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { t.Run("deleting user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 9e08317f6..9fbe70948 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -21,6 +21,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/client/system" @@ -31,8 +33,6 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -117,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManger), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) @@ -125,8 +125,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { groupsManager := groups.NewManagerMock() - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } From 10e9cf8c62ec7acef474cf96fa65283f2e20941a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 2 Dec 2025 14:13:01 +0100 Subject: [PATCH 083/118] [management] update management integrations (#4895) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 91e587c32..870118c88 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 98d395ad1..96a303f79 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 h1:ecs4GMANgObopiy29zMmz2dIdOTJMwezUbrFy+zfSwE= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63/go.mod h1:JIWpjbCgDvZIt45C9vYpikU2gRXeDWrN7SiyGYd3Qrc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= From a293f760af653afb0afe7635680c0ff5a7527128 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 2 Dec 2025 16:30:15 +0100 Subject: [PATCH 084/118] [client] Add conditional peer removal logic during shutdown (#4897) --- client/internal/engine.go | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 0ff1006cd..84bb8beea 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -292,6 +292,12 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { + log.Info("removing peers before dns") + if err := e.removeAllPeers(); err != nil { + return fmt.Errorf("failed to remove all peers: %s", err) + } + } if err := e.stopSSHServer(); err != nil { log.Warnf("failed to stop SSH server: %v", err) } @@ -310,6 +316,13 @@ func (e *Engine) Stop() error { e.stopDNSForwarder() + if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { + log.Info("removing peers before routes") + if err := e.removeAllPeers(); err != nil { + return fmt.Errorf("failed to remove all peers: %s", err) + } + } + if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } @@ -317,13 +330,16 @@ func (e *Engine) Stop() error { if e.srWatcher != nil { e.srWatcher.Close() } - + log.Info("cleaning up status recorder states") e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) - if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { + log.Info("removing peers after dns and routes") + if err := e.removeAllPeers(); err != nil { + return fmt.Errorf("failed to remove all peers: %s", err) + } } if e.cancel != nil { From a232cf614c33a58b9b7c3c522b855273696889c8 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:31:59 +0100 Subject: [PATCH 085/118] [management] record pat usage metrics (#4888) --- management/server/http/handler.go | 1 + .../server/http/middleware/auth_middleware.go | 17 ++++ .../http/middleware/auth_middleware_test.go | 7 ++ .../http/middleware/pat_usage_tracker.go | 87 +++++++++++++++++++ 4 files changed, 112 insertions(+) create mode 100644 management/server/http/middleware/pat_usage_tracker.go diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 7cf0b5765..b7c6c113c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -105,6 +105,7 @@ func NewAPIHandler( accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, rateLimitingConfig, + appMetrics.GetMeter(), ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 9439165a4..ffd7e0b93 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -9,6 +9,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -31,6 +32,7 @@ type AuthMiddleware struct { getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc rateLimiter *APIRateLimiter + patUsageTracker *PATUsageTracker } // NewAuthMiddleware instance constructor @@ -40,18 +42,29 @@ func NewAuthMiddleware( syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, rateLimiterConfig *RateLimiterConfig, + meter metric.Meter, ) *AuthMiddleware { var rateLimiter *APIRateLimiter if rateLimiterConfig != nil { rateLimiter = NewAPIRateLimiter(rateLimiterConfig) } + var patUsageTracker *PATUsageTracker + if meter != nil { + var err error + patUsageTracker, err = NewPATUsageTracker(context.Background(), meter) + if err != nil { + log.Errorf("Failed to create PAT usage tracker: %s", err) + } + } + return &AuthMiddleware{ authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, rateLimiter: rateLimiter, + patUsageTracker: patUsageTracker, } } @@ -158,6 +171,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] return r, fmt.Errorf("error extracting token: %w", err) } + if m.patUsageTracker != nil { + m.patUsageTracker.IncrementUsage(token) + } + if m.rateLimiter != nil { if !m.rateLimiter.Allow(token) { return r, status.Errorf(status.TooManyRequests, "too many requests") diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 7badc03e4..ba4d16796 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -208,6 +208,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { return &types.User{}, nil }, nil, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -266,6 +267,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -317,6 +319,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -359,6 +362,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -402,6 +406,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -465,6 +470,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -581,6 +587,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { return &types.User{}, nil }, nil, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/pat_usage_tracker.go b/management/server/http/middleware/pat_usage_tracker.go new file mode 100644 index 000000000..331c288e7 --- /dev/null +++ b/management/server/http/middleware/pat_usage_tracker.go @@ -0,0 +1,87 @@ +package middleware + +import ( + "context" + "maps" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" +) + +// PATUsageTracker tracks PAT usage metrics +type PATUsageTracker struct { + usageCounters map[string]int64 + mu sync.Mutex + stopChan chan struct{} + ctx context.Context + histogram metric.Int64Histogram +} + +// NewPATUsageTracker creates a new PAT usage tracker with metrics +func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) { + histogram, err := meter.Int64Histogram( + "management.pat.usage_distribution", + metric.WithUnit("1"), + metric.WithDescription("Distribution of PAT token usage counts per minute"), + ) + if err != nil { + return nil, err + } + + tracker := &PATUsageTracker{ + usageCounters: make(map[string]int64), + stopChan: make(chan struct{}), + ctx: ctx, + histogram: histogram, + } + + go tracker.reportLoop() + + return tracker, nil +} + +// IncrementUsage increments the usage counter for a given token +func (t *PATUsageTracker) IncrementUsage(token string) { + t.mu.Lock() + defer t.mu.Unlock() + t.usageCounters[token]++ +} + +// reportLoop reports the usage buckets every minute +func (t *PATUsageTracker) reportLoop() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + t.reportUsageBuckets() + case <-t.stopChan: + return + } + } +} + +// reportUsageBuckets reports all token usage counts and resets counters +func (t *PATUsageTracker) reportUsageBuckets() { + t.mu.Lock() + snapshot := maps.Clone(t.usageCounters) + + clear(t.usageCounters) + t.mu.Unlock() + + totalTokens := len(snapshot) + if totalTokens > 0 { + for _, count := range snapshot { + t.histogram.Record(t.ctx, count) + } + log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens) + } +} + +// Stop stops the reporting goroutine +func (t *PATUsageTracker) Stop() { + close(t.stopChan) +} From e87b4ace11988e9dec38d5f8ccf5c4334a9fac05 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 3 Dec 2025 11:53:39 +0100 Subject: [PATCH 086/118] [client] Add sleep state tracking to handle wakeup/sleep events reliably (#4894) Adds a new NotifyOSLifecycle RPC and server handler to centralize OS sleep/wake handling, introduces Server.sleepTriggeredDown for coordination, updates client UI to call the new RPC, and adjusts the internal sleep event enum zero-value semantics. --- client/internal/sleep/service.go | 5 +- client/proto/daemon.pb.go | 1019 +++++++++++++++++------------- client/proto/daemon.proto | 19 +- client/proto/daemon_grpc.pb.go | 40 +- client/server/lifecycle.go | 77 +++ client/server/lifecycle_test.go | 219 +++++++ client/server/server.go | 3 + client/ui/client_ui.go | 26 +- 8 files changed, 954 insertions(+), 454 deletions(-) create mode 100644 client/server/lifecycle.go create mode 100644 client/server/lifecycle_test.go diff --git a/client/internal/sleep/service.go b/client/internal/sleep/service.go index 35fc933c0..196a33f52 100644 --- a/client/internal/sleep/service.go +++ b/client/internal/sleep/service.go @@ -1,8 +1,9 @@ package sleep var ( - EventTypeSleep EventType = 0 - EventTypeWakeUp EventType = 1 + EventTypeUnknown EventType = 0 + EventTypeSleep EventType = 1 + EventTypeWakeUp EventType = 2 ) type EventType int diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 6f8255615..28e8b2d4e 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v3.21.12 +// protoc v6.32.1 // source: daemon.proto package proto @@ -88,6 +88,56 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } +// avoid collision with loglevel enum +type OSLifecycleRequest_CycleType int32 + +const ( + OSLifecycleRequest_UNKNOWN OSLifecycleRequest_CycleType = 0 + OSLifecycleRequest_SLEEP OSLifecycleRequest_CycleType = 1 + OSLifecycleRequest_WAKEUP OSLifecycleRequest_CycleType = 2 +) + +// Enum value maps for OSLifecycleRequest_CycleType. +var ( + OSLifecycleRequest_CycleType_name = map[int32]string{ + 0: "UNKNOWN", + 1: "SLEEP", + 2: "WAKEUP", + } + OSLifecycleRequest_CycleType_value = map[string]int32{ + "UNKNOWN": 0, + "SLEEP": 1, + "WAKEUP": 2, + } +) + +func (x OSLifecycleRequest_CycleType) Enum() *OSLifecycleRequest_CycleType { + p := new(OSLifecycleRequest_CycleType) + *p = x + return p +} + +func (x OSLifecycleRequest_CycleType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { + return file_daemon_proto_enumTypes[1].Descriptor() +} + +func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { + return &file_daemon_proto_enumTypes[1] +} + +func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use OSLifecycleRequest_CycleType.Descriptor instead. +func (OSLifecycleRequest_CycleType) EnumDescriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{1, 0} +} + type SystemEvent_Severity int32 const ( @@ -124,11 +174,11 @@ func (x SystemEvent_Severity) String() string { } func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[1].Descriptor() + return file_daemon_proto_enumTypes[2].Descriptor() } func (SystemEvent_Severity) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[1] + return &file_daemon_proto_enumTypes[2] } func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { @@ -137,7 +187,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Severity.Descriptor instead. func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51, 0} + return file_daemon_proto_rawDescGZIP(), []int{53, 0} } type SystemEvent_Category int32 @@ -179,11 +229,11 @@ func (x SystemEvent_Category) String() string { } func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[2].Descriptor() + return file_daemon_proto_enumTypes[3].Descriptor() } func (SystemEvent_Category) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[2] + return &file_daemon_proto_enumTypes[3] } func (x SystemEvent_Category) Number() protoreflect.EnumNumber { @@ -192,7 +242,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Category.Descriptor instead. func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51, 1} + return file_daemon_proto_rawDescGZIP(), []int{53, 1} } type EmptyRequest struct { @@ -231,6 +281,86 @@ func (*EmptyRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } +type OSLifecycleRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Type OSLifecycleRequest_CycleType `protobuf:"varint,1,opt,name=type,proto3,enum=daemon.OSLifecycleRequest_CycleType" json:"type,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OSLifecycleRequest) Reset() { + *x = OSLifecycleRequest{} + mi := &file_daemon_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OSLifecycleRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OSLifecycleRequest) ProtoMessage() {} + +func (x *OSLifecycleRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OSLifecycleRequest.ProtoReflect.Descriptor instead. +func (*OSLifecycleRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{1} +} + +func (x *OSLifecycleRequest) GetType() OSLifecycleRequest_CycleType { + if x != nil { + return x.Type + } + return OSLifecycleRequest_UNKNOWN +} + +type OSLifecycleResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OSLifecycleResponse) Reset() { + *x = OSLifecycleResponse{} + mi := &file_daemon_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OSLifecycleResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OSLifecycleResponse) ProtoMessage() {} + +func (x *OSLifecycleResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OSLifecycleResponse.ProtoReflect.Descriptor instead. +func (*OSLifecycleResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{2} +} + type LoginRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // setupKey netbird setup key. @@ -293,7 +423,7 @@ type LoginRequest struct { func (x *LoginRequest) Reset() { *x = LoginRequest{} - mi := &file_daemon_proto_msgTypes[1] + mi := &file_daemon_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -305,7 +435,7 @@ func (x *LoginRequest) String() string { func (*LoginRequest) ProtoMessage() {} func (x *LoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[1] + mi := &file_daemon_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -318,7 +448,7 @@ func (x *LoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead. func (*LoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{1} + return file_daemon_proto_rawDescGZIP(), []int{3} } func (x *LoginRequest) GetSetupKey() string { @@ -607,7 +737,7 @@ type LoginResponse struct { func (x *LoginResponse) Reset() { *x = LoginResponse{} - mi := &file_daemon_proto_msgTypes[2] + mi := &file_daemon_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -619,7 +749,7 @@ func (x *LoginResponse) String() string { func (*LoginResponse) ProtoMessage() {} func (x *LoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[2] + mi := &file_daemon_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -632,7 +762,7 @@ func (x *LoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead. func (*LoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{2} + return file_daemon_proto_rawDescGZIP(), []int{4} } func (x *LoginResponse) GetNeedsSSOLogin() bool { @@ -673,7 +803,7 @@ type WaitSSOLoginRequest struct { func (x *WaitSSOLoginRequest) Reset() { *x = WaitSSOLoginRequest{} - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -685,7 +815,7 @@ func (x *WaitSSOLoginRequest) String() string { func (*WaitSSOLoginRequest) ProtoMessage() {} func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -698,7 +828,7 @@ func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginRequest.ProtoReflect.Descriptor instead. func (*WaitSSOLoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{3} + return file_daemon_proto_rawDescGZIP(), []int{5} } func (x *WaitSSOLoginRequest) GetUserCode() string { @@ -724,7 +854,7 @@ type WaitSSOLoginResponse struct { func (x *WaitSSOLoginResponse) Reset() { *x = WaitSSOLoginResponse{} - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -736,7 +866,7 @@ func (x *WaitSSOLoginResponse) String() string { func (*WaitSSOLoginResponse) ProtoMessage() {} func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -749,7 +879,7 @@ func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginResponse.ProtoReflect.Descriptor instead. func (*WaitSSOLoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{4} + return file_daemon_proto_rawDescGZIP(), []int{6} } func (x *WaitSSOLoginResponse) GetEmail() string { @@ -769,7 +899,7 @@ type UpRequest struct { func (x *UpRequest) Reset() { *x = UpRequest{} - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -781,7 +911,7 @@ func (x *UpRequest) String() string { func (*UpRequest) ProtoMessage() {} func (x *UpRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -794,7 +924,7 @@ func (x *UpRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use UpRequest.ProtoReflect.Descriptor instead. func (*UpRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{5} + return file_daemon_proto_rawDescGZIP(), []int{7} } func (x *UpRequest) GetProfileName() string { @@ -819,7 +949,7 @@ type UpResponse struct { func (x *UpResponse) Reset() { *x = UpResponse{} - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -831,7 +961,7 @@ func (x *UpResponse) String() string { func (*UpResponse) ProtoMessage() {} func (x *UpResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -844,7 +974,7 @@ func (x *UpResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use UpResponse.ProtoReflect.Descriptor instead. func (*UpResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{6} + return file_daemon_proto_rawDescGZIP(), []int{8} } type StatusRequest struct { @@ -859,7 +989,7 @@ type StatusRequest struct { func (x *StatusRequest) Reset() { *x = StatusRequest{} - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -871,7 +1001,7 @@ func (x *StatusRequest) String() string { func (*StatusRequest) ProtoMessage() {} func (x *StatusRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -884,7 +1014,7 @@ func (x *StatusRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead. func (*StatusRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{7} + return file_daemon_proto_rawDescGZIP(), []int{9} } func (x *StatusRequest) GetGetFullPeerStatus() bool { @@ -921,7 +1051,7 @@ type StatusResponse struct { func (x *StatusResponse) Reset() { *x = StatusResponse{} - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -933,7 +1063,7 @@ func (x *StatusResponse) String() string { func (*StatusResponse) ProtoMessage() {} func (x *StatusResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -946,7 +1076,7 @@ func (x *StatusResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead. func (*StatusResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{8} + return file_daemon_proto_rawDescGZIP(), []int{10} } func (x *StatusResponse) GetStatus() string { @@ -978,7 +1108,7 @@ type DownRequest struct { func (x *DownRequest) Reset() { *x = DownRequest{} - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -990,7 +1120,7 @@ func (x *DownRequest) String() string { func (*DownRequest) ProtoMessage() {} func (x *DownRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1003,7 +1133,7 @@ func (x *DownRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DownRequest.ProtoReflect.Descriptor instead. func (*DownRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{9} + return file_daemon_proto_rawDescGZIP(), []int{11} } type DownResponse struct { @@ -1014,7 +1144,7 @@ type DownResponse struct { func (x *DownResponse) Reset() { *x = DownResponse{} - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1026,7 +1156,7 @@ func (x *DownResponse) String() string { func (*DownResponse) ProtoMessage() {} func (x *DownResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1039,7 +1169,7 @@ func (x *DownResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DownResponse.ProtoReflect.Descriptor instead. func (*DownResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{10} + return file_daemon_proto_rawDescGZIP(), []int{12} } type GetConfigRequest struct { @@ -1052,7 +1182,7 @@ type GetConfigRequest struct { func (x *GetConfigRequest) Reset() { *x = GetConfigRequest{} - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1064,7 +1194,7 @@ func (x *GetConfigRequest) String() string { func (*GetConfigRequest) ProtoMessage() {} func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1077,7 +1207,7 @@ func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigRequest.ProtoReflect.Descriptor instead. func (*GetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{11} + return file_daemon_proto_rawDescGZIP(), []int{13} } func (x *GetConfigRequest) GetProfileName() string { @@ -1133,7 +1263,7 @@ type GetConfigResponse struct { func (x *GetConfigResponse) Reset() { *x = GetConfigResponse{} - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1145,7 +1275,7 @@ func (x *GetConfigResponse) String() string { func (*GetConfigResponse) ProtoMessage() {} func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1158,7 +1288,7 @@ func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigResponse.ProtoReflect.Descriptor instead. func (*GetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{12} + return file_daemon_proto_rawDescGZIP(), []int{14} } func (x *GetConfigResponse) GetManagementUrl() string { @@ -1370,7 +1500,7 @@ type PeerState struct { func (x *PeerState) Reset() { *x = PeerState{} - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1382,7 +1512,7 @@ func (x *PeerState) String() string { func (*PeerState) ProtoMessage() {} func (x *PeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1395,7 +1525,7 @@ func (x *PeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerState.ProtoReflect.Descriptor instead. func (*PeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{13} + return file_daemon_proto_rawDescGZIP(), []int{15} } func (x *PeerState) GetIP() string { @@ -1540,7 +1670,7 @@ type LocalPeerState struct { func (x *LocalPeerState) Reset() { *x = LocalPeerState{} - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1552,7 +1682,7 @@ func (x *LocalPeerState) String() string { func (*LocalPeerState) ProtoMessage() {} func (x *LocalPeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1565,7 +1695,7 @@ func (x *LocalPeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use LocalPeerState.ProtoReflect.Descriptor instead. func (*LocalPeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{14} + return file_daemon_proto_rawDescGZIP(), []int{16} } func (x *LocalPeerState) GetIP() string { @@ -1629,7 +1759,7 @@ type SignalState struct { func (x *SignalState) Reset() { *x = SignalState{} - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1641,7 +1771,7 @@ func (x *SignalState) String() string { func (*SignalState) ProtoMessage() {} func (x *SignalState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1654,7 +1784,7 @@ func (x *SignalState) ProtoReflect() protoreflect.Message { // Deprecated: Use SignalState.ProtoReflect.Descriptor instead. func (*SignalState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{15} + return file_daemon_proto_rawDescGZIP(), []int{17} } func (x *SignalState) GetURL() string { @@ -1690,7 +1820,7 @@ type ManagementState struct { func (x *ManagementState) Reset() { *x = ManagementState{} - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1702,7 +1832,7 @@ func (x *ManagementState) String() string { func (*ManagementState) ProtoMessage() {} func (x *ManagementState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1715,7 +1845,7 @@ func (x *ManagementState) ProtoReflect() protoreflect.Message { // Deprecated: Use ManagementState.ProtoReflect.Descriptor instead. func (*ManagementState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{16} + return file_daemon_proto_rawDescGZIP(), []int{18} } func (x *ManagementState) GetURL() string { @@ -1751,7 +1881,7 @@ type RelayState struct { func (x *RelayState) Reset() { *x = RelayState{} - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1763,7 +1893,7 @@ func (x *RelayState) String() string { func (*RelayState) ProtoMessage() {} func (x *RelayState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1776,7 +1906,7 @@ func (x *RelayState) ProtoReflect() protoreflect.Message { // Deprecated: Use RelayState.ProtoReflect.Descriptor instead. func (*RelayState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{17} + return file_daemon_proto_rawDescGZIP(), []int{19} } func (x *RelayState) GetURI() string { @@ -1812,7 +1942,7 @@ type NSGroupState struct { func (x *NSGroupState) Reset() { *x = NSGroupState{} - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1824,7 +1954,7 @@ func (x *NSGroupState) String() string { func (*NSGroupState) ProtoMessage() {} func (x *NSGroupState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1837,7 +1967,7 @@ func (x *NSGroupState) ProtoReflect() protoreflect.Message { // Deprecated: Use NSGroupState.ProtoReflect.Descriptor instead. func (*NSGroupState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{18} + return file_daemon_proto_rawDescGZIP(), []int{20} } func (x *NSGroupState) GetServers() []string { @@ -1881,7 +2011,7 @@ type SSHSessionInfo struct { func (x *SSHSessionInfo) Reset() { *x = SSHSessionInfo{} - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1893,7 +2023,7 @@ func (x *SSHSessionInfo) String() string { func (*SSHSessionInfo) ProtoMessage() {} func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1906,7 +2036,7 @@ func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead. func (*SSHSessionInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{19} + return file_daemon_proto_rawDescGZIP(), []int{21} } func (x *SSHSessionInfo) GetUsername() string { @@ -1948,7 +2078,7 @@ type SSHServerState struct { func (x *SSHServerState) Reset() { *x = SSHServerState{} - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1960,7 +2090,7 @@ func (x *SSHServerState) String() string { func (*SSHServerState) ProtoMessage() {} func (x *SSHServerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1973,7 +2103,7 @@ func (x *SSHServerState) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead. func (*SSHServerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{20} + return file_daemon_proto_rawDescGZIP(), []int{22} } func (x *SSHServerState) GetEnabled() bool { @@ -2009,7 +2139,7 @@ type FullStatus struct { func (x *FullStatus) Reset() { *x = FullStatus{} - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2021,7 +2151,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2034,7 +2164,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{21} + return file_daemon_proto_rawDescGZIP(), []int{23} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -2116,7 +2246,7 @@ type ListNetworksRequest struct { func (x *ListNetworksRequest) Reset() { *x = ListNetworksRequest{} - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2128,7 +2258,7 @@ func (x *ListNetworksRequest) String() string { func (*ListNetworksRequest) ProtoMessage() {} func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2141,7 +2271,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. func (*ListNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{22} + return file_daemon_proto_rawDescGZIP(), []int{24} } type ListNetworksResponse struct { @@ -2153,7 +2283,7 @@ type ListNetworksResponse struct { func (x *ListNetworksResponse) Reset() { *x = ListNetworksResponse{} - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2165,7 +2295,7 @@ func (x *ListNetworksResponse) String() string { func (*ListNetworksResponse) ProtoMessage() {} func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2178,7 +2308,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. func (*ListNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{25} } func (x *ListNetworksResponse) GetRoutes() []*Network { @@ -2199,7 +2329,7 @@ type SelectNetworksRequest struct { func (x *SelectNetworksRequest) Reset() { *x = SelectNetworksRequest{} - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2211,7 +2341,7 @@ func (x *SelectNetworksRequest) String() string { func (*SelectNetworksRequest) ProtoMessage() {} func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2224,7 +2354,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *SelectNetworksRequest) GetNetworkIDs() []string { @@ -2256,7 +2386,7 @@ type SelectNetworksResponse struct { func (x *SelectNetworksResponse) Reset() { *x = SelectNetworksResponse{} - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2268,7 +2398,7 @@ func (x *SelectNetworksResponse) String() string { func (*SelectNetworksResponse) ProtoMessage() {} func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2281,7 +2411,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{27} } type IPList struct { @@ -2293,7 +2423,7 @@ type IPList struct { func (x *IPList) Reset() { *x = IPList{} - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2305,7 +2435,7 @@ func (x *IPList) String() string { func (*IPList) ProtoMessage() {} func (x *IPList) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2318,7 +2448,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message { // Deprecated: Use IPList.ProtoReflect.Descriptor instead. func (*IPList) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *IPList) GetIps() []string { @@ -2341,7 +2471,7 @@ type Network struct { func (x *Network) Reset() { *x = Network{} - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2353,7 +2483,7 @@ func (x *Network) String() string { func (*Network) ProtoMessage() {} func (x *Network) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2366,7 +2496,7 @@ func (x *Network) ProtoReflect() protoreflect.Message { // Deprecated: Use Network.ProtoReflect.Descriptor instead. func (*Network) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *Network) GetID() string { @@ -2418,7 +2548,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2430,7 +2560,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2443,7 +2573,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{30} } func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -2500,7 +2630,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2512,7 +2642,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2525,7 +2655,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{31} } func (x *ForwardingRule) GetProtocol() string { @@ -2572,7 +2702,7 @@ type ForwardingRulesResponse struct { func (x *ForwardingRulesResponse) Reset() { *x = ForwardingRulesResponse{} - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2584,7 +2714,7 @@ func (x *ForwardingRulesResponse) String() string { func (*ForwardingRulesResponse) ProtoMessage() {} func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2597,7 +2727,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRulesResponse.ProtoReflect.Descriptor instead. func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30} + return file_daemon_proto_rawDescGZIP(), []int{32} } func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule { @@ -2621,7 +2751,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2633,7 +2763,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2646,7 +2776,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{33} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -2695,7 +2825,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2707,7 +2837,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2720,7 +2850,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{34} } func (x *DebugBundleResponse) GetPath() string { @@ -2752,7 +2882,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2764,7 +2894,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2777,7 +2907,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{33} + return file_daemon_proto_rawDescGZIP(), []int{35} } type GetLogLevelResponse struct { @@ -2789,7 +2919,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2801,7 +2931,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2814,7 +2944,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{34} + return file_daemon_proto_rawDescGZIP(), []int{36} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -2833,7 +2963,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2845,7 +2975,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2858,7 +2988,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{35} + return file_daemon_proto_rawDescGZIP(), []int{37} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -2876,7 +3006,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2888,7 +3018,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2901,7 +3031,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{36} + return file_daemon_proto_rawDescGZIP(), []int{38} } // State represents a daemon state entry @@ -2914,7 +3044,7 @@ type State struct { func (x *State) Reset() { *x = State{} - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2926,7 +3056,7 @@ func (x *State) String() string { func (*State) ProtoMessage() {} func (x *State) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2939,7 +3069,7 @@ func (x *State) ProtoReflect() protoreflect.Message { // Deprecated: Use State.ProtoReflect.Descriptor instead. func (*State) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{37} + return file_daemon_proto_rawDescGZIP(), []int{39} } func (x *State) GetName() string { @@ -2958,7 +3088,7 @@ type ListStatesRequest struct { func (x *ListStatesRequest) Reset() { *x = ListStatesRequest{} - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2970,7 +3100,7 @@ func (x *ListStatesRequest) String() string { func (*ListStatesRequest) ProtoMessage() {} func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2983,7 +3113,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. func (*ListStatesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{38} + return file_daemon_proto_rawDescGZIP(), []int{40} } // ListStatesResponse contains a list of states @@ -2996,7 +3126,7 @@ type ListStatesResponse struct { func (x *ListStatesResponse) Reset() { *x = ListStatesResponse{} - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3008,7 +3138,7 @@ func (x *ListStatesResponse) String() string { func (*ListStatesResponse) ProtoMessage() {} func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3021,7 +3151,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. func (*ListStatesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{39} + return file_daemon_proto_rawDescGZIP(), []int{41} } func (x *ListStatesResponse) GetStates() []*State { @@ -3042,7 +3172,7 @@ type CleanStateRequest struct { func (x *CleanStateRequest) Reset() { *x = CleanStateRequest{} - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3054,7 +3184,7 @@ func (x *CleanStateRequest) String() string { func (*CleanStateRequest) ProtoMessage() {} func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3067,7 +3197,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. func (*CleanStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{40} + return file_daemon_proto_rawDescGZIP(), []int{42} } func (x *CleanStateRequest) GetStateName() string { @@ -3094,7 +3224,7 @@ type CleanStateResponse struct { func (x *CleanStateResponse) Reset() { *x = CleanStateResponse{} - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3106,7 +3236,7 @@ func (x *CleanStateResponse) String() string { func (*CleanStateResponse) ProtoMessage() {} func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3119,7 +3249,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. func (*CleanStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{41} + return file_daemon_proto_rawDescGZIP(), []int{43} } func (x *CleanStateResponse) GetCleanedStates() int32 { @@ -3140,7 +3270,7 @@ type DeleteStateRequest struct { func (x *DeleteStateRequest) Reset() { *x = DeleteStateRequest{} - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3152,7 +3282,7 @@ func (x *DeleteStateRequest) String() string { func (*DeleteStateRequest) ProtoMessage() {} func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3165,7 +3295,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. func (*DeleteStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{42} + return file_daemon_proto_rawDescGZIP(), []int{44} } func (x *DeleteStateRequest) GetStateName() string { @@ -3192,7 +3322,7 @@ type DeleteStateResponse struct { func (x *DeleteStateResponse) Reset() { *x = DeleteStateResponse{} - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3204,7 +3334,7 @@ func (x *DeleteStateResponse) String() string { func (*DeleteStateResponse) ProtoMessage() {} func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3217,7 +3347,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. func (*DeleteStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{43} + return file_daemon_proto_rawDescGZIP(), []int{45} } func (x *DeleteStateResponse) GetDeletedStates() int32 { @@ -3236,7 +3366,7 @@ type SetSyncResponsePersistenceRequest struct { func (x *SetSyncResponsePersistenceRequest) Reset() { *x = SetSyncResponsePersistenceRequest{} - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3248,7 +3378,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string { func (*SetSyncResponsePersistenceRequest) ProtoMessage() {} func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3261,7 +3391,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{44} + return file_daemon_proto_rawDescGZIP(), []int{46} } func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool { @@ -3279,7 +3409,7 @@ type SetSyncResponsePersistenceResponse struct { func (x *SetSyncResponsePersistenceResponse) Reset() { *x = SetSyncResponsePersistenceResponse{} - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3291,7 +3421,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string { func (*SetSyncResponsePersistenceResponse) ProtoMessage() {} func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3304,7 +3434,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{45} + return file_daemon_proto_rawDescGZIP(), []int{47} } type TCPFlags struct { @@ -3321,7 +3451,7 @@ type TCPFlags struct { func (x *TCPFlags) Reset() { *x = TCPFlags{} - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3333,7 +3463,7 @@ func (x *TCPFlags) String() string { func (*TCPFlags) ProtoMessage() {} func (x *TCPFlags) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3346,7 +3476,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message { // Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. func (*TCPFlags) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{46} + return file_daemon_proto_rawDescGZIP(), []int{48} } func (x *TCPFlags) GetSyn() bool { @@ -3408,7 +3538,7 @@ type TracePacketRequest struct { func (x *TracePacketRequest) Reset() { *x = TracePacketRequest{} - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3420,7 +3550,7 @@ func (x *TracePacketRequest) String() string { func (*TracePacketRequest) ProtoMessage() {} func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3433,7 +3563,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. func (*TracePacketRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{47} + return file_daemon_proto_rawDescGZIP(), []int{49} } func (x *TracePacketRequest) GetSourceIp() string { @@ -3511,7 +3641,7 @@ type TraceStage struct { func (x *TraceStage) Reset() { *x = TraceStage{} - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3523,7 +3653,7 @@ func (x *TraceStage) String() string { func (*TraceStage) ProtoMessage() {} func (x *TraceStage) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3536,7 +3666,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message { // Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. func (*TraceStage) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{48} + return file_daemon_proto_rawDescGZIP(), []int{50} } func (x *TraceStage) GetName() string { @@ -3577,7 +3707,7 @@ type TracePacketResponse struct { func (x *TracePacketResponse) Reset() { *x = TracePacketResponse{} - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3589,7 +3719,7 @@ func (x *TracePacketResponse) String() string { func (*TracePacketResponse) ProtoMessage() {} func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3602,7 +3732,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. func (*TracePacketResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49} + return file_daemon_proto_rawDescGZIP(), []int{51} } func (x *TracePacketResponse) GetStages() []*TraceStage { @@ -3627,7 +3757,7 @@ type SubscribeRequest struct { func (x *SubscribeRequest) Reset() { *x = SubscribeRequest{} - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3639,7 +3769,7 @@ func (x *SubscribeRequest) String() string { func (*SubscribeRequest) ProtoMessage() {} func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3652,7 +3782,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. func (*SubscribeRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{50} + return file_daemon_proto_rawDescGZIP(), []int{52} } type SystemEvent struct { @@ -3670,7 +3800,7 @@ type SystemEvent struct { func (x *SystemEvent) Reset() { *x = SystemEvent{} - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3682,7 +3812,7 @@ func (x *SystemEvent) String() string { func (*SystemEvent) ProtoMessage() {} func (x *SystemEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3695,7 +3825,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead. func (*SystemEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51} + return file_daemon_proto_rawDescGZIP(), []int{53} } func (x *SystemEvent) GetId() string { @@ -3755,7 +3885,7 @@ type GetEventsRequest struct { func (x *GetEventsRequest) Reset() { *x = GetEventsRequest{} - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3767,7 +3897,7 @@ func (x *GetEventsRequest) String() string { func (*GetEventsRequest) ProtoMessage() {} func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3780,7 +3910,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead. func (*GetEventsRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{52} + return file_daemon_proto_rawDescGZIP(), []int{54} } type GetEventsResponse struct { @@ -3792,7 +3922,7 @@ type GetEventsResponse struct { func (x *GetEventsResponse) Reset() { *x = GetEventsResponse{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3804,7 +3934,7 @@ func (x *GetEventsResponse) String() string { func (*GetEventsResponse) ProtoMessage() {} func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3817,7 +3947,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead. func (*GetEventsResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53} + return file_daemon_proto_rawDescGZIP(), []int{55} } func (x *GetEventsResponse) GetEvents() []*SystemEvent { @@ -3837,7 +3967,7 @@ type SwitchProfileRequest struct { func (x *SwitchProfileRequest) Reset() { *x = SwitchProfileRequest{} - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3849,7 +3979,7 @@ func (x *SwitchProfileRequest) String() string { func (*SwitchProfileRequest) ProtoMessage() {} func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3862,7 +3992,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{54} + return file_daemon_proto_rawDescGZIP(), []int{56} } func (x *SwitchProfileRequest) GetProfileName() string { @@ -3887,7 +4017,7 @@ type SwitchProfileResponse struct { func (x *SwitchProfileResponse) Reset() { *x = SwitchProfileResponse{} - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3899,7 +4029,7 @@ func (x *SwitchProfileResponse) String() string { func (*SwitchProfileResponse) ProtoMessage() {} func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3912,7 +4042,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{55} + return file_daemon_proto_rawDescGZIP(), []int{57} } type SetConfigRequest struct { @@ -3960,7 +4090,7 @@ type SetConfigRequest struct { func (x *SetConfigRequest) Reset() { *x = SetConfigRequest{} - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3972,7 +4102,7 @@ func (x *SetConfigRequest) String() string { func (*SetConfigRequest) ProtoMessage() {} func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3985,7 +4115,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. func (*SetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{56} + return file_daemon_proto_rawDescGZIP(), []int{58} } func (x *SetConfigRequest) GetUsername() string { @@ -4234,7 +4364,7 @@ type SetConfigResponse struct { func (x *SetConfigResponse) Reset() { *x = SetConfigResponse{} - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4246,7 +4376,7 @@ func (x *SetConfigResponse) String() string { func (*SetConfigResponse) ProtoMessage() {} func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4259,7 +4389,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. func (*SetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{57} + return file_daemon_proto_rawDescGZIP(), []int{59} } type AddProfileRequest struct { @@ -4272,7 +4402,7 @@ type AddProfileRequest struct { func (x *AddProfileRequest) Reset() { *x = AddProfileRequest{} - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4284,7 +4414,7 @@ func (x *AddProfileRequest) String() string { func (*AddProfileRequest) ProtoMessage() {} func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4297,7 +4427,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. func (*AddProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{58} + return file_daemon_proto_rawDescGZIP(), []int{60} } func (x *AddProfileRequest) GetUsername() string { @@ -4322,7 +4452,7 @@ type AddProfileResponse struct { func (x *AddProfileResponse) Reset() { *x = AddProfileResponse{} - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4334,7 +4464,7 @@ func (x *AddProfileResponse) String() string { func (*AddProfileResponse) ProtoMessage() {} func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4347,7 +4477,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead. func (*AddProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{59} + return file_daemon_proto_rawDescGZIP(), []int{61} } type RemoveProfileRequest struct { @@ -4360,7 +4490,7 @@ type RemoveProfileRequest struct { func (x *RemoveProfileRequest) Reset() { *x = RemoveProfileRequest{} - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4372,7 +4502,7 @@ func (x *RemoveProfileRequest) String() string { func (*RemoveProfileRequest) ProtoMessage() {} func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4385,7 +4515,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{60} + return file_daemon_proto_rawDescGZIP(), []int{62} } func (x *RemoveProfileRequest) GetUsername() string { @@ -4410,7 +4540,7 @@ type RemoveProfileResponse struct { func (x *RemoveProfileResponse) Reset() { *x = RemoveProfileResponse{} - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4422,7 +4552,7 @@ func (x *RemoveProfileResponse) String() string { func (*RemoveProfileResponse) ProtoMessage() {} func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4435,7 +4565,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{61} + return file_daemon_proto_rawDescGZIP(), []int{63} } type ListProfilesRequest struct { @@ -4447,7 +4577,7 @@ type ListProfilesRequest struct { func (x *ListProfilesRequest) Reset() { *x = ListProfilesRequest{} - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4459,7 +4589,7 @@ func (x *ListProfilesRequest) String() string { func (*ListProfilesRequest) ProtoMessage() {} func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4472,7 +4602,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. func (*ListProfilesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{62} + return file_daemon_proto_rawDescGZIP(), []int{64} } func (x *ListProfilesRequest) GetUsername() string { @@ -4491,7 +4621,7 @@ type ListProfilesResponse struct { func (x *ListProfilesResponse) Reset() { *x = ListProfilesResponse{} - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4503,7 +4633,7 @@ func (x *ListProfilesResponse) String() string { func (*ListProfilesResponse) ProtoMessage() {} func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4516,7 +4646,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. func (*ListProfilesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{63} + return file_daemon_proto_rawDescGZIP(), []int{65} } func (x *ListProfilesResponse) GetProfiles() []*Profile { @@ -4536,7 +4666,7 @@ type Profile struct { func (x *Profile) Reset() { *x = Profile{} - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4548,7 +4678,7 @@ func (x *Profile) String() string { func (*Profile) ProtoMessage() {} func (x *Profile) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4561,7 +4691,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message { // Deprecated: Use Profile.ProtoReflect.Descriptor instead. func (*Profile) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{64} + return file_daemon_proto_rawDescGZIP(), []int{66} } func (x *Profile) GetName() string { @@ -4586,7 +4716,7 @@ type GetActiveProfileRequest struct { func (x *GetActiveProfileRequest) Reset() { *x = GetActiveProfileRequest{} - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4598,7 +4728,7 @@ func (x *GetActiveProfileRequest) String() string { func (*GetActiveProfileRequest) ProtoMessage() {} func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4611,7 +4741,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{65} + return file_daemon_proto_rawDescGZIP(), []int{67} } type GetActiveProfileResponse struct { @@ -4624,7 +4754,7 @@ type GetActiveProfileResponse struct { func (x *GetActiveProfileResponse) Reset() { *x = GetActiveProfileResponse{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4636,7 +4766,7 @@ func (x *GetActiveProfileResponse) String() string { func (*GetActiveProfileResponse) ProtoMessage() {} func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4649,7 +4779,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{66} + return file_daemon_proto_rawDescGZIP(), []int{68} } func (x *GetActiveProfileResponse) GetProfileName() string { @@ -4676,7 +4806,7 @@ type LogoutRequest struct { func (x *LogoutRequest) Reset() { *x = LogoutRequest{} - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4688,7 +4818,7 @@ func (x *LogoutRequest) String() string { func (*LogoutRequest) ProtoMessage() {} func (x *LogoutRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4701,7 +4831,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. func (*LogoutRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{67} + return file_daemon_proto_rawDescGZIP(), []int{69} } func (x *LogoutRequest) GetProfileName() string { @@ -4726,7 +4856,7 @@ type LogoutResponse struct { func (x *LogoutResponse) Reset() { *x = LogoutResponse{} - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4738,7 +4868,7 @@ func (x *LogoutResponse) String() string { func (*LogoutResponse) ProtoMessage() {} func (x *LogoutResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4751,7 +4881,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. func (*LogoutResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{68} + return file_daemon_proto_rawDescGZIP(), []int{70} } type GetFeaturesRequest struct { @@ -4762,7 +4892,7 @@ type GetFeaturesRequest struct { func (x *GetFeaturesRequest) Reset() { *x = GetFeaturesRequest{} - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[71] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4774,7 +4904,7 @@ func (x *GetFeaturesRequest) String() string { func (*GetFeaturesRequest) ProtoMessage() {} func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[71] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4787,7 +4917,7 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead. func (*GetFeaturesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{69} + return file_daemon_proto_rawDescGZIP(), []int{71} } type GetFeaturesResponse struct { @@ -4800,7 +4930,7 @@ type GetFeaturesResponse struct { func (x *GetFeaturesResponse) Reset() { *x = GetFeaturesResponse{} - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[72] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4812,7 +4942,7 @@ func (x *GetFeaturesResponse) String() string { func (*GetFeaturesResponse) ProtoMessage() {} func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[72] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4825,7 +4955,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead. func (*GetFeaturesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{70} + return file_daemon_proto_rawDescGZIP(), []int{72} } func (x *GetFeaturesResponse) GetDisableProfiles() bool { @@ -4853,7 +4983,7 @@ type GetPeerSSHHostKeyRequest struct { func (x *GetPeerSSHHostKeyRequest) Reset() { *x = GetPeerSSHHostKeyRequest{} - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[73] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4865,7 +4995,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string { func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[73] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4878,7 +5008,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{71} + return file_daemon_proto_rawDescGZIP(), []int{73} } func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { @@ -4905,7 +5035,7 @@ type GetPeerSSHHostKeyResponse struct { func (x *GetPeerSSHHostKeyResponse) Reset() { *x = GetPeerSSHHostKeyResponse{} - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[74] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4917,7 +5047,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string { func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[74] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4930,7 +5060,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{72} + return file_daemon_proto_rawDescGZIP(), []int{74} } func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { @@ -4972,7 +5102,7 @@ type RequestJWTAuthRequest struct { func (x *RequestJWTAuthRequest) Reset() { *x = RequestJWTAuthRequest{} - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4984,7 +5114,7 @@ func (x *RequestJWTAuthRequest) String() string { func (*RequestJWTAuthRequest) ProtoMessage() {} func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4997,7 +5127,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{73} + return file_daemon_proto_rawDescGZIP(), []int{75} } func (x *RequestJWTAuthRequest) GetHint() string { @@ -5030,7 +5160,7 @@ type RequestJWTAuthResponse struct { func (x *RequestJWTAuthResponse) Reset() { *x = RequestJWTAuthResponse{} - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5042,7 +5172,7 @@ func (x *RequestJWTAuthResponse) String() string { func (*RequestJWTAuthResponse) ProtoMessage() {} func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5055,7 +5185,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{74} + return file_daemon_proto_rawDescGZIP(), []int{76} } func (x *RequestJWTAuthResponse) GetVerificationURI() string { @@ -5120,7 +5250,7 @@ type WaitJWTTokenRequest struct { func (x *WaitJWTTokenRequest) Reset() { *x = WaitJWTTokenRequest{} - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5132,7 +5262,7 @@ func (x *WaitJWTTokenRequest) String() string { func (*WaitJWTTokenRequest) ProtoMessage() {} func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5145,7 +5275,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{75} + return file_daemon_proto_rawDescGZIP(), []int{77} } func (x *WaitJWTTokenRequest) GetDeviceCode() string { @@ -5177,7 +5307,7 @@ type WaitJWTTokenResponse struct { func (x *WaitJWTTokenResponse) Reset() { *x = WaitJWTTokenResponse{} - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5189,7 +5319,7 @@ func (x *WaitJWTTokenResponse) String() string { func (*WaitJWTTokenResponse) ProtoMessage() {} func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5202,7 +5332,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{76} + return file_daemon_proto_rawDescGZIP(), []int{78} } func (x *WaitJWTTokenResponse) GetToken() string { @@ -5236,7 +5366,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5248,7 +5378,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5261,7 +5391,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28, 0} + return file_daemon_proto_rawDescGZIP(), []int{30, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -5283,7 +5413,15 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xb6\x12\n" + + "\fEmptyRequest\"\x7f\n" + + "\x12OSLifecycleRequest\x128\n" + + "\x04type\x18\x01 \x01(\x0e2$.daemon.OSLifecycleRequest.CycleTypeR\x04type\"/\n" + + "\tCycleType\x12\v\n" + + "\aUNKNOWN\x10\x00\x12\t\n" + + "\x05SLEEP\x10\x01\x12\n" + + "\n" + + "\x06WAKEUP\x10\x02\"\x15\n" + + "\x13OSLifecycleResponse\"\xb6\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -5764,7 +5902,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\x8b\x12\n" + + "\x05TRACE\x10\a2\xdb\x12\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -5799,7 +5937,8 @@ const file_daemon_proto_rawDesc = "" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + - "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" + + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -5813,196 +5952,202 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 80) +var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel - (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity - (SystemEvent_Category)(0), // 2: daemon.SystemEvent.Category - (*EmptyRequest)(nil), // 3: daemon.EmptyRequest - (*LoginRequest)(nil), // 4: daemon.LoginRequest - (*LoginResponse)(nil), // 5: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 6: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 7: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 8: daemon.UpRequest - (*UpResponse)(nil), // 9: daemon.UpResponse - (*StatusRequest)(nil), // 10: daemon.StatusRequest - (*StatusResponse)(nil), // 11: daemon.StatusResponse - (*DownRequest)(nil), // 12: daemon.DownRequest - (*DownResponse)(nil), // 13: daemon.DownResponse - (*GetConfigRequest)(nil), // 14: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 15: daemon.GetConfigResponse - (*PeerState)(nil), // 16: daemon.PeerState - (*LocalPeerState)(nil), // 17: daemon.LocalPeerState - (*SignalState)(nil), // 18: daemon.SignalState - (*ManagementState)(nil), // 19: daemon.ManagementState - (*RelayState)(nil), // 20: daemon.RelayState - (*NSGroupState)(nil), // 21: daemon.NSGroupState - (*SSHSessionInfo)(nil), // 22: daemon.SSHSessionInfo - (*SSHServerState)(nil), // 23: daemon.SSHServerState - (*FullStatus)(nil), // 24: daemon.FullStatus - (*ListNetworksRequest)(nil), // 25: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 26: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 27: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 28: daemon.SelectNetworksResponse - (*IPList)(nil), // 29: daemon.IPList - (*Network)(nil), // 30: daemon.Network - (*PortInfo)(nil), // 31: daemon.PortInfo - (*ForwardingRule)(nil), // 32: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 33: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 34: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 35: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 36: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 37: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 38: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 39: daemon.SetLogLevelResponse - (*State)(nil), // 40: daemon.State - (*ListStatesRequest)(nil), // 41: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 42: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 43: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 44: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 45: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 46: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 47: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 48: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 49: daemon.TCPFlags - (*TracePacketRequest)(nil), // 50: daemon.TracePacketRequest - (*TraceStage)(nil), // 51: daemon.TraceStage - (*TracePacketResponse)(nil), // 52: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 53: daemon.SubscribeRequest - (*SystemEvent)(nil), // 54: daemon.SystemEvent - (*GetEventsRequest)(nil), // 55: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 56: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 57: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 58: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 59: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 60: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 61: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 62: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 63: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 64: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 65: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 66: daemon.ListProfilesResponse - (*Profile)(nil), // 67: daemon.Profile - (*GetActiveProfileRequest)(nil), // 68: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 69: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 70: daemon.LogoutRequest - (*LogoutResponse)(nil), // 71: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 72: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 73: daemon.GetFeaturesResponse - (*GetPeerSSHHostKeyRequest)(nil), // 74: daemon.GetPeerSSHHostKeyRequest - (*GetPeerSSHHostKeyResponse)(nil), // 75: daemon.GetPeerSSHHostKeyResponse - (*RequestJWTAuthRequest)(nil), // 76: daemon.RequestJWTAuthRequest - (*RequestJWTAuthResponse)(nil), // 77: daemon.RequestJWTAuthResponse - (*WaitJWTTokenRequest)(nil), // 78: daemon.WaitJWTTokenRequest - (*WaitJWTTokenResponse)(nil), // 79: daemon.WaitJWTTokenResponse - nil, // 80: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 81: daemon.PortInfo.Range - nil, // 82: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 83: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 84: google.protobuf.Timestamp + (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType + (SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity + (SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category + (*EmptyRequest)(nil), // 4: daemon.EmptyRequest + (*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest + (*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse + (*LoginRequest)(nil), // 7: daemon.LoginRequest + (*LoginResponse)(nil), // 8: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 11: daemon.UpRequest + (*UpResponse)(nil), // 12: daemon.UpResponse + (*StatusRequest)(nil), // 13: daemon.StatusRequest + (*StatusResponse)(nil), // 14: daemon.StatusResponse + (*DownRequest)(nil), // 15: daemon.DownRequest + (*DownResponse)(nil), // 16: daemon.DownResponse + (*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse + (*PeerState)(nil), // 19: daemon.PeerState + (*LocalPeerState)(nil), // 20: daemon.LocalPeerState + (*SignalState)(nil), // 21: daemon.SignalState + (*ManagementState)(nil), // 22: daemon.ManagementState + (*RelayState)(nil), // 23: daemon.RelayState + (*NSGroupState)(nil), // 24: daemon.NSGroupState + (*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 26: daemon.SSHServerState + (*FullStatus)(nil), // 27: daemon.FullStatus + (*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse + (*IPList)(nil), // 32: daemon.IPList + (*Network)(nil), // 33: daemon.Network + (*PortInfo)(nil), // 34: daemon.PortInfo + (*ForwardingRule)(nil), // 35: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse + (*State)(nil), // 43: daemon.State + (*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 52: daemon.TCPFlags + (*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest + (*TraceStage)(nil), // 54: daemon.TraceStage + (*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest + (*SystemEvent)(nil), // 57: daemon.SystemEvent + (*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse + (*Profile)(nil), // 70: daemon.Profile + (*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), // 73: daemon.LogoutRequest + (*LogoutResponse)(nil), // 74: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse + (*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse + nil, // 83: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range + nil, // 85: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 86: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 83, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 24, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 84, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 84, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 83, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 22, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo - 19, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 18, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 17, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 16, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState - 20, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState - 21, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 54, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 23, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState - 30, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 80, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 81, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 31, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 31, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 32, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 40, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State - 49, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 51, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 1, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 2, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 84, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 82, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 54, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 83, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 67, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 29, // 32: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 4, // 33: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 6, // 34: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 8, // 35: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 10, // 36: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 12, // 37: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 14, // 38: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 25, // 39: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 27, // 40: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 27, // 41: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 3, // 42: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 34, // 43: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 36, // 44: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 38, // 45: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 41, // 46: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 43, // 47: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 45, // 48: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 47, // 49: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 50, // 50: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 53, // 51: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 55, // 52: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 57, // 53: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 59, // 54: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 61, // 55: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 63, // 56: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 65, // 57: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 68, // 58: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 70, // 59: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 72, // 60: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 74, // 61: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 76, // 62: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 78, // 63: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 5, // 64: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 7, // 65: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 9, // 66: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 11, // 67: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 13, // 68: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 15, // 69: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 26, // 70: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 28, // 71: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 28, // 72: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 33, // 73: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 35, // 74: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 37, // 75: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 39, // 76: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 42, // 77: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 44, // 78: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 46, // 79: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 48, // 80: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 52, // 81: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 54, // 82: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 56, // 83: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 58, // 84: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 60, // 85: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 62, // 86: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 64, // 87: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 66, // 88: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 69, // 89: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 71, // 90: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 73, // 91: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 75, // 92: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 77, // 93: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 79, // 94: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 64, // [64:95] is the sub-list for method output_type - 33, // [33:64] is the sub-list for method input_type - 33, // [33:33] is the sub-list for extension type_name - 33, // [33:33] is the sub-list for extension extendee - 0, // [0:33] is the sub-list for field type_name + 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType + 86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState + 23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState + 24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State + 52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 66, // [66:98] is the sub-list for method output_type + 34, // [34:66] is the sub-list for method input_type + 34, // [34:34] is the sub-list for extension type_name + 34, // [34:34] is the sub-list for extension extendee + 0, // [0:34] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -6010,26 +6155,26 @@ func file_daemon_proto_init() { if File_daemon_proto != nil { return } - file_daemon_proto_msgTypes[1].OneofWrappers = []any{} - file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[3].OneofWrappers = []any{} file_daemon_proto_msgTypes[7].OneofWrappers = []any{} - file_daemon_proto_msgTypes[28].OneofWrappers = []any{ + file_daemon_proto_msgTypes[9].OneofWrappers = []any{} + file_daemon_proto_msgTypes[30].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } - file_daemon_proto_msgTypes[47].OneofWrappers = []any{} - file_daemon_proto_msgTypes[48].OneofWrappers = []any{} - file_daemon_proto_msgTypes[54].OneofWrappers = []any{} + file_daemon_proto_msgTypes[49].OneofWrappers = []any{} + file_daemon_proto_msgTypes[50].OneofWrappers = []any{} file_daemon_proto_msgTypes[56].OneofWrappers = []any{} - file_daemon_proto_msgTypes[67].OneofWrappers = []any{} - file_daemon_proto_msgTypes[73].OneofWrappers = []any{} + file_daemon_proto_msgTypes[58].OneofWrappers = []any{} + file_daemon_proto_msgTypes[69].OneofWrappers = []any{} + file_daemon_proto_msgTypes[75].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), - NumEnums: 3, - NumMessages: 80, + NumEnums: 4, + NumMessages: 82, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index bf8553706..3dfd3da8d 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -24,7 +24,7 @@ service DaemonService { // Status of the service. rpc Status(StatusRequest) returns (StatusResponse) {} - // Down engine work in the daemon. + // Down stops engine work in the daemon. rpc Down(DownRequest) returns (DownResponse) {} // GetConfig of the daemon. @@ -93,9 +93,26 @@ service DaemonService { // WaitJWTToken waits for JWT authentication completion rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} + + rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} } + +message OSLifecycleRequest { + // avoid collision with loglevel enum + enum CycleType { + UNKNOWN = 0; + SLEEP = 1; + WAKEUP = 2; + } + + CycleType type = 1; +} + +message OSLifecycleResponse {} + + message LoginRequest { // setupKey netbird setup key. string setupKey = 1; diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index b2bf716b2..6b01309b7 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -27,7 +27,7 @@ type DaemonServiceClient interface { Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) // Status of the service. Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) - // Down engine work in the daemon. + // Down stops engine work in the daemon. Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) @@ -70,6 +70,7 @@ type DaemonServiceClient interface { RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) + NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) } type daemonServiceClient struct { @@ -382,6 +383,15 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken return out, nil } +func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) { + out := new(OSLifecycleResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -395,7 +405,7 @@ type DaemonServiceServer interface { Up(context.Context, *UpRequest) (*UpResponse, error) // Status of the service. Status(context.Context, *StatusRequest) (*StatusResponse, error) - // Down engine work in the daemon. + // Down stops engine work in the daemon. Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) @@ -438,6 +448,7 @@ type DaemonServiceServer interface { RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) + NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -538,6 +549,9 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") } +func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1112,6 +1126,24 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } +func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(OSLifecycleRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/NotifyOSLifecycle", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1239,6 +1271,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "WaitJWTToken", Handler: _DaemonService_WaitJWTToken_Handler, }, + { + MethodName: "NotifyOSLifecycle", + Handler: _DaemonService_NotifyOSLifecycle_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/server/lifecycle.go b/client/server/lifecycle.go new file mode 100644 index 000000000..3722c027d --- /dev/null +++ b/client/server/lifecycle.go @@ -0,0 +1,77 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" +) + +// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type. +func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) { + switch req.GetType() { + case proto.OSLifecycleRequest_WAKEUP: + return s.handleWakeUp(callerCtx) + case proto.OSLifecycleRequest_SLEEP: + return s.handleSleep(callerCtx) + default: + log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType()) + } + return &proto.OSLifecycleResponse{}, nil +} + +// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep. +// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails. +func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { + if !s.sleepTriggeredDown.Load() { + log.Info("skipping up because wasn't sleep down") + return &proto.OSLifecycleResponse{}, nil + } + + // avoid other wakeup runs if sleep didn't make the computer sleep + s.sleepTriggeredDown.Store(false) + + log.Info("running up after wake up") + _, err := s.Up(callerCtx, &proto.UpRequest{}) + if err != nil { + log.Errorf("running up failed: %v", err) + return &proto.OSLifecycleResponse{}, err + } + + log.Info("running up command executed successfully") + return &proto.OSLifecycleResponse{}, nil +} + +// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state. +func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { + s.mutex.Lock() + + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return &proto.OSLifecycleResponse{}, err + } + + if status != internal.StatusConnecting && status != internal.StatusConnected { + log.Infof("skipping setting the agent down because status is %s", status) + s.mutex.Unlock() + return &proto.OSLifecycleResponse{}, nil + } + s.mutex.Unlock() + + log.Info("running down after system started sleeping") + + _, err = s.Down(callerCtx, &proto.DownRequest{}) + if err != nil { + log.Errorf("running down failed: %v", err) + return &proto.OSLifecycleResponse{}, err + } + + s.sleepTriggeredDown.Store(true) + + log.Info("running down executed successfully") + return &proto.OSLifecycleResponse{}, nil +} diff --git a/client/server/lifecycle_test.go b/client/server/lifecycle_test.go new file mode 100644 index 000000000..a604c60af --- /dev/null +++ b/client/server/lifecycle_test.go @@ -0,0 +1,219 @@ +package server + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/proto" +) + +func newTestServer() *Server { + ctx := internal.CtxInitState(context.Background()) + return &Server{ + rootCtx: ctx, + statusRecorder: peer.NewRecorder(""), + } +} + +func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) { + s := newTestServer() + + // sleepTriggeredDown is false by default + assert.False(t, s.sleepTriggeredDown.Load()) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false") +} + +func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusIdle) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle") +} + +func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusNeedsLogin) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin") +} + +func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusConnecting) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + assert.NotNil(t, resp, "handleSleep returns not nil response on success") + assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting") +} + +func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusConnected) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + assert.NotNil(t, resp, "handleSleep returns not nil response on success") + assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected") +} + +func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) { + s := newTestServer() + + // Manually set the flag to simulate prior sleep down + s.sleepTriggeredDown.Store(true) + + // WakeUp will try to call Up which fails without proper setup, but flag should reset first + _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + + assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt") +} + +func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) { + s := newTestServer() + + // First wakeup without prior sleep - should be no-op + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) + + // Simulate prior sleep + s.sleepTriggeredDown.Store(true) + + // First wakeup after sleep - should reset flag + _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + assert.False(t, s.sleepTriggeredDown.Load()) + + // Second wakeup - should be no-op + resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) +} + +func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) { + s := newTestServer() + + resp, err := s.handleWakeUp(context.Background()) + + require.NoError(t, err) + require.NotNil(t, resp) +} + +func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) { + s := newTestServer() + s.sleepTriggeredDown.Store(true) + + // Even if Up fails, flag should be reset + _, _ = s.handleWakeUp(context.Background()) + + assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up") +} + +func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Idle", internal.StatusIdle}, + {"NeedsLogin", internal.StatusNeedsLogin}, + {"LoginFailed", internal.StatusLoginFailed}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := newTestServer() + state := internal.CtxGetState(s.rootCtx) + state.Set(tt.status) + + resp, err := s.handleSleep(context.Background()) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) + }) + } +} + +func TestHandleSleep_ProceedsForActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Connecting", internal.StatusConnecting}, + {"Connected", internal.StatusConnected}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := newTestServer() + state := internal.CtxGetState(s.rootCtx) + state.Set(tt.status) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.handleSleep(ctx) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, s.sleepTriggeredDown.Load()) + }) + } +} diff --git a/client/server/server.go b/client/server/server.go index 49000c092..cf79ec5ed 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -85,6 +85,9 @@ type Server struct { profilesDisabled bool updateSettingsDisabled bool + // sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down + sleepTriggeredDown atomic.Bool + jwtCache *jwtCache } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 57d0e74a0..8f99608e7 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -1196,25 +1196,27 @@ func (s *serviceClient) handleSleepEvents(event sleep.EventType) { return } + req := &proto.OSLifecycleRequest{} + switch event { case sleep.EventTypeWakeUp: log.Infof("handle wakeup event: %v", event) - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - return - } - return + req.Type = proto.OSLifecycleRequest_WAKEUP case sleep.EventTypeSleep: log.Infof("handle sleep event: %v", event) - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) - return - } + req.Type = proto.OSLifecycleRequest_SLEEP + default: + log.Infof("unknown event: %v", event) + return } - log.Info("successfully notified daemon about sleep/wakeup event") + _, err = conn.NotifyOSLifecycle(s.ctx, req) + if err != nil { + log.Errorf("failed to notify daemon about os lifecycle notification: %v", err) + return + } + + log.Info("successfully notified daemon about os lifecycle") } // setSettingsEnabled enables or disables the settings menu based on the provided state From 27dd97c9c40ae889fd1fc936ddc7cdb67d290300 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 3 Dec 2025 14:45:59 +0300 Subject: [PATCH 087/118] [management] Add support to disable geolocation service (#4901) --- management/internals/server/modules.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 91ce50a79..af9ca5f2d 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -2,6 +2,7 @@ package server import ( "context" + "os" log "github.com/sirupsen/logrus" @@ -21,7 +22,16 @@ import ( "github.com/netbirdio/netbird/management/server/users" ) +const ( + geolocationDisabledKey = "NB_DISABLE_GEOLOCATION" +) + func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { + if os.Getenv(geolocationDisabledKey) == "true" { + log.Info("geolocation service is disabled, skipping initialization") + return nil + } + return Create(s, func() geolocation.Geolocation { geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate) if err != nil { From d2e48d4f5e3dba008393b3cee233ed9877ee1f9a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 3 Dec 2025 18:42:53 +0100 Subject: [PATCH 088/118] [relay] Use instanceURL instead of Exposed address. (#4905) Replaces string-based exposed address handling with URL-based InstanceURL() (type url.URL) across relay/server and relay/healthcheck; adds SchemeREL/SchemeRELS constants; updates getInstanceURL to return *url.URL with scheme and TLS validation; adjusts WS dialing and health-check logic to use URL fields. --- relay/cmd/root.go | 3 ++- relay/healthcheck/healthcheck.go | 18 +++++++---------- relay/healthcheck/ws.go | 12 +++++------ relay/server/relay.go | 16 ++++++--------- relay/server/server.go | 12 ++++------- relay/server/url.go | 22 +++++++++++++++------ relay/server/{relay_test.go => url_test.go} | 9 ++++++--- 7 files changed, 47 insertions(+), 45 deletions(-) rename relay/server/{relay_test.go => url_test.go} (78%) diff --git a/relay/cmd/root.go b/relay/cmd/root.go index eb2cdebf8..e7dadcfdf 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -160,7 +160,8 @@ func execute(cmd *cobra.Command, args []string) error { log.Debugf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err) } - log.Infof("server will be available on: %s", srv.InstanceURL()) + instanceURL := srv.InstanceURL() + log.Infof("server will be available on: %s", instanceURL.String()) wg.Add(1) go func() { defer wg.Done() diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go index 6463843eb..b54d4b33b 100644 --- a/relay/healthcheck/healthcheck.go +++ b/relay/healthcheck/healthcheck.go @@ -6,13 +6,14 @@ import ( "errors" "net" "net/http" - "strings" + "net/url" "sync" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + "github.com/netbirdio/netbird/relay/server" ) const ( @@ -26,7 +27,7 @@ const ( type ServiceChecker interface { ListenerProtocols() []protocol.Protocol - ExposedAddress() string + InstanceURL() url.URL } type HealthStatus struct { @@ -134,7 +135,7 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) { } status.Listeners = listeners - if !strings.HasPrefix(s.config.ServiceChecker.ExposedAddress(), "rels") { + if s.config.ServiceChecker.InstanceURL().Scheme != server.SchemeRELS { status.CertificateValid = false } @@ -156,14 +157,9 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) { } func (s *Server) validateConnection(ctx context.Context) bool { - exposedAddress := s.config.ServiceChecker.ExposedAddress() - if exposedAddress == "" { - log.Error("exposed address is empty, cannot validate certificate") - return false - } - - if err := dialWS(ctx, exposedAddress); err != nil { - log.Errorf("failed to dial WebSocket listener at %s: %v", exposedAddress, err) + addr := s.config.ServiceChecker.InstanceURL() + if err := dialWS(ctx, addr); err != nil { + log.Errorf("failed to dial WebSocket listener at %s: %v", addr.String(), err) return false } diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index badd31219..db61ed802 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -3,22 +3,22 @@ package healthcheck import ( "context" "fmt" - "strings" + "net/url" "github.com/coder/websocket" + "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay" ) -func dialWS(ctx context.Context, address string) error { - addressSplit := strings.Split(address, "/") +func dialWS(ctx context.Context, address url.URL) error { scheme := "ws" - if addressSplit[0] == "rels:" { + if address.Scheme == server.SchemeRELS { scheme = "wss" } - url := fmt.Sprintf("%s://%s%s", scheme, addressSplit[2], relay.WebSocketURLPath) + wsURL := fmt.Sprintf("%s://%s%s", scheme, address.Host, relay.WebSocketURLPath) - conn, resp, err := websocket.Dial(ctx, url, nil) + conn, resp, err := websocket.Dial(ctx, wsURL, nil) if resp != nil { defer func() { if resp.Body != nil { diff --git a/relay/server/relay.go b/relay/server/relay.go index aab575bf0..c1cfa13fd 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/url" "sync" "time" @@ -22,7 +23,7 @@ type Config struct { TLSSupport bool AuthValidator Validator - instanceURL string + instanceURL url.URL } func (c *Config) validate() error { @@ -37,7 +38,7 @@ func (c *Config) validate() error { if err != nil { return fmt.Errorf("invalid url: %v", err) } - c.instanceURL = instanceURL + c.instanceURL = *instanceURL if c.AuthValidator == nil { return fmt.Errorf("auth validator is required") @@ -53,7 +54,7 @@ type Relay struct { store *store.Store notifier *store.PeerNotifier - instanceURL string + instanceURL url.URL exposedAddress string preparedMsg *preparedMsg @@ -97,7 +98,7 @@ func NewRelay(config Config) (*Relay, error) { notifier: store.NewPeerNotifier(), } - r.preparedMsg, err = newPreparedMsg(r.instanceURL) + r.preparedMsg, err = newPreparedMsg(r.instanceURL.String()) if err != nil { metricsCancel() return nil, fmt.Errorf("prepare message: %v", err) @@ -177,11 +178,6 @@ func (r *Relay) Shutdown(ctx context.Context) { } // InstanceURL returns the instance URL of the relay server -func (r *Relay) InstanceURL() string { +func (r *Relay) InstanceURL() url.URL { return r.instanceURL } - -// ExposedAddress returns the exposed address (domain:port) where clients connect -func (r *Relay) ExposedAddress() string { - return r.exposedAddress -} diff --git a/relay/server/server.go b/relay/server/server.go index 2c9e658d6..8e4333064 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/tls" + "net/url" "sync" "github.com/hashicorp/go-multierror" @@ -39,7 +40,7 @@ type Server struct { // // config: A Config struct containing the necessary configuration: // - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used. -// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required. +// - InstanceURL: The public address (in domain:port format) used as the server's instance URL. Required. // - TLSSupport: A boolean indicating whether TLS is enabled for the server. // - AuthValidator: A Validator used to authenticate peers. Required. // @@ -119,11 +120,6 @@ func (r *Server) Shutdown(ctx context.Context) error { return nberrors.FormatErrorOrNil(multiErr) } -// InstanceURL returns the instance URL of the relay server. -func (r *Server) InstanceURL() string { - return r.relay.instanceURL -} - func (r *Server) ListenerProtocols() []protocol.Protocol { result := make([]protocol.Protocol, 0) @@ -135,6 +131,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol { return result } -func (r *Server) ExposedAddress() string { - return r.relay.ExposedAddress() +func (r *Server) InstanceURL() url.URL { + return r.relay.InstanceURL() } diff --git a/relay/server/url.go b/relay/server/url.go index 9cbf44642..aeae1c068 100644 --- a/relay/server/url.go +++ b/relay/server/url.go @@ -6,9 +6,14 @@ import ( "strings" ) +const ( + SchemeREL = "rel" + SchemeRELS = "rels" +) + // getInstanceURL checks if user supplied a URL scheme otherwise adds to the // provided address according to TLS definition and parses the address before returning it -func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { +func getInstanceURL(exposedAddress string, tlsSupported bool) (*url.URL, error) { addr := exposedAddress split := strings.Split(exposedAddress, "://") switch { @@ -17,17 +22,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { case len(split) == 1 && !tlsSupported: addr = "rel://" + exposedAddress case len(split) > 2: - return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) + return nil, fmt.Errorf("invalid exposed address: %s", exposedAddress) } parsedURL, err := url.ParseRequestURI(addr) if err != nil { - return "", fmt.Errorf("invalid exposed address: %v", err) + return nil, fmt.Errorf("invalid exposed address: %v", err) } - if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { - return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) + if parsedURL.Scheme != SchemeREL && parsedURL.Scheme != SchemeRELS { + return nil, fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) } - return parsedURL.String(), nil + // Validate scheme matches TLS configuration + if tlsSupported && parsedURL.Scheme == SchemeREL { + return nil, fmt.Errorf("non-TLS scheme '%s' provided but TLS is supported", SchemeREL) + } + + return parsedURL, nil } diff --git a/relay/server/relay_test.go b/relay/server/url_test.go similarity index 78% rename from relay/server/relay_test.go rename to relay/server/url_test.go index 062039ab9..ca455f45a 100644 --- a/relay/server/relay_test.go +++ b/relay/server/url_test.go @@ -13,7 +13,7 @@ func TestGetInstanceURL(t *testing.T) { {"Valid address with TLS", "example.com", true, "rels://example.com", false}, {"Valid address without TLS", "example.com", false, "rel://example.com", false}, {"Valid address with scheme", "rel://example.com", false, "rel://example.com", false}, - {"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false}, + {"Invalid address with non TLS scheme and TLS true", "rel://example.com", true, "", true}, {"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false}, {"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false}, {"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false}, @@ -28,8 +28,11 @@ func TestGetInstanceURL(t *testing.T) { if (err != nil) != tt.expectError { t.Errorf("expected error: %v, got: %v", tt.expectError, err) } - if url != tt.expectedURL { - t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url) + if !tt.expectError && url != nil && url.String() != tt.expectedURL { + t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url.String()) + } + if tt.expectError && url != nil { + t.Errorf("expected nil URL on error, got: %s", url.String()) } }) } From 031ab1117800fec893f4c433cbd1c3999882e6b4 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 4 Dec 2025 16:57:29 +0300 Subject: [PATCH 089/118] [client] Remove select account prompt (#4912) Signed-off-by: bcmmbaga --- client/internal/auth/pkce_flow.go | 3 +-- client/internal/auth/pkce_flow_test.go | 13 ++++++------- shared/management/client/common/types.go | 8 ++++---- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index c03376f5b..cc43c8648 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -107,10 +107,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn if !p.providerConfig.DisablePromptLogin { switch p.providerConfig.LoginFlag { case common.LoginFlagPromptLogin: - params = append(params, oauth2.SetAuthURLParam("prompt", "login select_account")) + params = append(params, oauth2.SetAuthURLParam("prompt", "login")) case common.LoginFlagMaxAge0: params = append(params, oauth2.SetAuthURLParam("max_age", "0")) - params = append(params, oauth2.SetAuthURLParam("prompt", "select_account")) } } if p.providerConfig.LoginHint != "" { diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b5843f104..b77a17eaa 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -15,9 +15,8 @@ import ( func TestPromptLogin(t *testing.T) { const ( - promptSelectAccountLogin = "prompt=login+select_account" - promptSelectAccount = "prompt=select_account" - maxAge0 = "max_age=0" + promptLogin = "prompt=login" + maxAge0 = "max_age=0" ) tt := []struct { @@ -27,14 +26,14 @@ func TestPromptLogin(t *testing.T) { expectContains []string }{ { - name: "Prompt login with select account", + name: "Prompt login", loginFlag: mgm.LoginFlagPromptLogin, - expectContains: []string{promptSelectAccountLogin}, + expectContains: []string{promptLogin}, }, { - name: "Max age 0 with select account", + name: "Max age 0", loginFlag: mgm.LoginFlagMaxAge0, - expectContains: []string{maxAge0, promptSelectAccount}, + expectContains: []string{maxAge0}, }, { name: "Disable prompt login", diff --git a/shared/management/client/common/types.go b/shared/management/client/common/types.go index 550bcde30..451578358 100644 --- a/shared/management/client/common/types.go +++ b/shared/management/client/common/types.go @@ -6,14 +6,14 @@ package common // // | Value | Flag | OAuth Parameters | // |-------|----------------------|-----------------------------------------| -// | 0 | LoginFlagPromptLogin | prompt=select_account login | -// | 1 | LoginFlagMaxAge0 | max_age=0 & prompt=select_account | +// | 0 | LoginFlagPromptLogin | prompt=login | +// | 1 | LoginFlagMaxAge0 | max_age=0 | type LoginFlag uint8 const ( - // LoginFlagPromptLogin adds prompt=select_account login to the authorization request + // LoginFlagPromptLogin adds prompt=login to the authorization request LoginFlagPromptLogin LoginFlag = iota - // LoginFlagMaxAge0 adds max_age=0 and prompt=select_account to the authorization request + // LoginFlagMaxAge0 adds max_age=0 to the authorization request LoginFlagMaxAge0 // LoginFlagNone disables all login flags LoginFlagNone From 9bdc4908fb7a32001cebaa03ac754f3aa82a5e4c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:16:38 +0100 Subject: [PATCH 090/118] [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) --- client/firewall/nftables/router_linux.go | 279 ++++++++++++++++++----- 1 file changed, 225 insertions(+), 54 deletions(-) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e4debc179..7f95992da 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -27,7 +27,11 @@ import ( ) const ( - tableNat = "nat" + tableNat = "nat" + tableMangle = "mangle" + tableRaw = "raw" + tableSecurity = "security" + chainNameNatPrerouting = "PREROUTING" chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingNat = "netbird-rt-postrouting" @@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou var err error r.filterTable, err = r.loadFilterTable() if err != nil { - log.Warnf("failed to load filter table, skipping accept rules: %v", err) + log.Debugf("ip filter table not found: %v", err) } return r, nil @@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { return nil, errFilterTableNotFound } +func hookName(hook *nftables.ChainHook) string { + if hook == nil { + return "unknown" + } + switch *hook { + case *nftables.ChainHookForward: + return chainNameForward + case *nftables.ChainHookInput: + return chainNameInput + default: + return fmt.Sprintf("hook(%d)", *hook) + } +} + +func familyName(family nftables.TableFamily) string { + switch family { + case nftables.TableFamilyIPv4: + return "ip" + case nftables.TableFamilyIPv6: + return "ip6" + case nftables.TableFamilyINet: + return "inet" + default: + return fmt.Sprintf("family(%d)", family) + } +} + func (r *router) createContainers() error { r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, @@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error { // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { + var merr *multierror.Error + + if err := r.acceptFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) acceptFilterTableRules() error { if r.filterTable == nil { - log.Debugf("table 'filter' not found for forward rules, skipping accept rules") return nil } @@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error { // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { - // filter table exists but iptables is not + // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptFilterRulesNftables() + return r.acceptFilterRulesNftables(r.filterTable) } return r.acceptFilterRulesIptables(ipt) @@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err)) } else { log.Debugf("added iptables forward rule: %v", rule) } @@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { inputRule := r.getAcceptInputRule() if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err)) } else { log.Debugf("added iptables input rule: %v", inputRule) } @@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string { return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} } -func (r *router) acceptFilterRulesNftables() error { +// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables. +// This is used when iptables is not available. +func (r *router) acceptFilterRulesNftables(table *nftables.Table) error { intf := ifname(r.wgIface.Name()) + forwardChain := &nftables.Chain{ + Name: chainNameForward, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + } + r.insertForwardAcceptRules(forwardChain, intf) + + inputChain := &nftables.Chain{ + Name: chainNameInput, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + } + r.insertInputAcceptRule(inputChain, intf) + + return r.conn.Flush() +} + +// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables). +// It dynamically finds chains at call time to handle chains that may have been created after startup. +func (r *router) acceptExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + intf := ifname(r.wgIface.Name()) + + for _, chain := range chains { + if chain.Hooknum == nil { + log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) + continue + } + + log.Debugf("adding accept rules to external %s chain: %s %s/%s", + hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) + + switch *chain.Hooknum { + case *nftables.ChainHookForward: + r.insertForwardAcceptRules(chain, intf) + case *nftables.ChainHookInput: + r.insertInputAcceptRule(chain, intf) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush external chain rules: %w", err) + } + + return nil +} + +func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { iifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameForward, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error { Data: intf, }, } - oifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameForward, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } r.conn.InsertRule(oifRule) +} +func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { inputRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameInput, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookInput, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error { UserData: []byte(userDataAcceptInputRule), } r.conn.InsertRule(inputRule) - - return r.conn.Flush() } func (r *router) removeAcceptFilterRules() error { + var merr *multierror.Error + + if err := r.removeFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.removeExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) removeFilterTableRules() error { if r.filterTable == nil { return nil } ipt, err := iptables.New() if err != nil { - log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptFilterRulesNftables() + log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) } return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptFilterRulesNftables() error { - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) +func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { + chains, err := r.conn.ListChainsOfTableFamily(table.Family) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name { + if chain.Table.Name != table.Name { continue } @@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error { continue } - rules, err := r.conn.GetRules(r.filterTable, chain) + if err := r.removeAcceptRulesFromChain(table, chain); err != nil { + return err + } + } + + return r.conn.Flush() +} + +func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error { + rules, err := r.conn.GetRules(table, chain) + if err != nil { + return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err) + } + } + } + return nil +} + +// removeExternalChainsRules removes our accept rules from all external chains. +// This is deterministic - it scans for chains at removal time rather than relying on saved state, +// ensuring cleanup works even after a crash or if chains changed. +func (r *router) removeExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + for _, chain := range chains { + if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil { + log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err) + } + } + + return r.conn.Flush() +} + +// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks. +// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal). +func (r *router) findExternalChains() []*nftables.Chain { + var chains []*nftables.Chain + + families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + + for _, family := range families { + allChains, err := r.conn.ListChainsOfTableFamily(family) if err != nil { - return fmt.Errorf("get rules: %v", err) + log.Debugf("list chains for family %d: %v", family, err) + continue } - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } + for _, chain := range allChains { + if r.isExternalChain(chain) { + chains = append(chains, chain) } } } - if err := r.conn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + return chains +} + +func (r *router) isExternalChain(chain *nftables.Chain) bool { + if r.workTable != nil && chain.Table.Name == r.workTable.Name { + return false } - return nil + // Skip all iptables-managed tables in the ip family + if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + return false + } + + if chain.Type != nftables.ChainTypeFilter { + return false + } + + if chain.Hooknum == nil { + return false + } + + return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput +} + +func isIptablesTable(name string) bool { + switch name { + case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity: + return true + } + return false } func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err)) } } inputRule := r.getAcceptInputRule() if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err)) } return nberrors.FormatErrorOrNil(merr) From 71b6855e0920b377edb203999139056833585534 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 4 Dec 2025 19:51:50 +0100 Subject: [PATCH 091/118] [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891) * Fix engine shutdown deadlock and message handling races - Release syncMsgMux before waiting for shutdownWg to prevent deadlock - Check context inside lock in handleSync and receiveSignalEvents - Prevents nil pointer access when messages arrive during engine stop --- client/internal/connect.go | 13 ++++++++----- client/internal/engine.go | 31 +++++++++++++++++++++---------- client/server/server.go | 2 ++ 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 5a5f4f63c..e9d422a28 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) - c.engine.SetSyncResponsePersistence(c.persistSyncResponse) + engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) + engine.SetSyncResponsePersistence(c.persistSyncResponse) + c.engine = engine c.engineMutex.Unlock() - if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { + if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } @@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan <-engineCtx.Done() c.engineMutex.Lock() - engine := c.engine c.engine = nil c.engineMutex.Unlock() - if engine != nil && engine.wgInterface != nil { + // todo: consider to remove this condition. Is not thread safe. + // We should always call Stop(), but we need to verify that it is idempotent + if engine.wgInterface != nil { log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 84bb8beea..ce46dac8c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -280,7 +280,6 @@ func (e *Engine) Stop() error { return nil } e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() if e.connMgr != nil { e.connMgr.Close() @@ -295,7 +294,7 @@ func (e *Engine) Stop() error { if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { log.Info("removing peers before dns") if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + log.Errorf("failed to remove all peers: %s", err) } } if err := e.stopSSHServer(); err != nil { @@ -319,7 +318,7 @@ func (e *Engine) Stop() error { if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { log.Info("removing peers before routes") if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + log.Errorf("failed to remove all peers: %s", err) } } @@ -338,7 +337,7 @@ func (e *Engine) Stop() error { if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { log.Info("removing peers after dns and routes") if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + log.Errorf("failed to remove all peers: %s", err) } } @@ -353,16 +352,18 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer stateCancel() - if err := e.stateManager.Stop(ctx); err != nil { - return fmt.Errorf("failed to stop state manager: %w", err) + if err := e.stateManager.Stop(stateCtx); err != nil { + log.Errorf("failed to stop state manager: %v", err) } if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } + e.syncMsgMux.Unlock() + timeout := e.calculateShutdownTimeout() log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) @@ -448,8 +449,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) if err != nil { return fmt.Errorf("create rosenpass manager: %w", err) } - err := e.rpManager.Run() - if err != nil { + if err := e.rpManager.Run(); err != nil { return fmt.Errorf("run rosenpass manager: %w", err) } } @@ -501,6 +501,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) } if err := e.createFirewall(); err != nil { + e.close() return err } @@ -766,6 +767,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + // Check context INSIDE lock to ensure atomicity with shutdown + if e.ctx.Err() != nil { + return e.ctx.Err() + } + if update.GetNetbirdConfig() != nil { wCfg := update.GetNetbirdConfig() err := e.updateTURNs(wCfg.GetTurns()) @@ -1385,6 +1391,11 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + // Check context INSIDE lock to ensure atomicity with shutdown + if e.ctx.Err() != nil { + return e.ctx.Err() + } + conn, ok := e.peerStore.PeerConn(msg.Key) if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) diff --git a/client/server/server.go b/client/server/server.go index cf79ec5ed..d33595115 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -822,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes defer s.mutex.Unlock() if err := s.cleanupConnection(); err != nil { + // todo review to update the status in case any type of error log.Errorf("failed to shut down properly: %v", err) return nil, err } @@ -914,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe } if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) { + // todo review to update the status in case any type of error log.Errorf("failed to cleanup connection: %v", err) return nil, err } From cb6b086164608d66c1dbaf33adaa2e4d9a6b1905 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 4 Dec 2025 21:01:22 +0100 Subject: [PATCH 092/118] [client] Reorder subsystem shutdown so peer removal goes first (#4914) Remove peers before DNS and routes --- client/internal/engine.go | 42 +++++++++++++-------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index ce46dac8c..ff1cec19a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -291,21 +291,12 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") - if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { - log.Info("removing peers before dns") - if err := e.removeAllPeers(); err != nil { - log.Errorf("failed to remove all peers: %s", err) - } - } if err := e.stopSSHServer(); err != nil { log.Warnf("failed to stop SSH server: %v", err) } e.cleanupSSHConfig() - // stop/restore DNS first so dbus and friends don't complain because of a missing interface - e.stopDNSServer() - if e.ingressGatewayMgr != nil { if err := e.ingressGatewayMgr.Close(); err != nil { log.Warnf("failed to cleanup forward rules: %v", err) @@ -313,33 +304,28 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } - e.stopDNSForwarder() + if e.srWatcher != nil { + e.srWatcher.Close() + } - if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { - log.Info("removing peers before routes") - if err := e.removeAllPeers(); err != nil { - log.Errorf("failed to remove all peers: %s", err) - } + log.Info("cleaning up status recorder states") + e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) + e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) + e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) + + if err := e.removeAllPeers(); err != nil { + log.Errorf("failed to remove all peers: %s", err) } if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } - if e.srWatcher != nil { - e.srWatcher.Close() - } - log.Info("cleaning up status recorder states") - e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) - e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) - e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) + e.stopDNSForwarder() - if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { - log.Info("removing peers after dns and routes") - if err := e.removeAllPeers(); err != nil { - log.Errorf("failed to remove all peers: %s", err) - } - } + // stop/restore DNS after peers are closed but before interface goes down + // so dbus and friends don't complain because of a missing interface + e.stopDNSServer() if e.cancel != nil { e.cancel() From f538e6e9ae215d1e4baf48f93bbe6fbcb0f1e303 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:29:27 +0100 Subject: [PATCH 093/118] [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) --- client/ssh/server/command_execution_js.go | 5 +++ client/ssh/server/command_execution_unix.go | 26 ++++++++++++- .../ssh/server/command_execution_windows.go | 5 +++ client/ssh/server/server.go | 4 +- client/ssh/server/userswitching_unix.go | 39 ++++++++++++++++--- 5 files changed, 71 insertions(+), 8 deletions(-) diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go index 6473f8273..01759a337 100644 --- a/client/ssh/server/command_execution_js.go +++ b/client/ssh/server/command_execution_js.go @@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool { return false } +// detectUtilLinuxLogin always returns false on JS/WASM +func (s *Server) detectUtilLinuxLogin(context.Context) bool { + return false +} + // executeCommandWithPty is not supported on JS/WASM func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { logger.Errorf("PTY command execution not supported on JS/WASM") diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index da059fed9..db1a9bcfe 100644 --- a/client/ssh/server/command_execution_unix.go +++ b/client/ssh/server/command_execution_unix.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "os/user" + "runtime" "strings" "sync" "syscall" @@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool { return supported } +// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils). +// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent. +// See https://bugs.debian.org/1078023 for details. +func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { + if runtime.GOOS != "linux" { + return false + } + + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + cmd := exec.CommandContext(ctx, "login", "--version") + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("login --version failed (likely shadow-utils): %v", err) + return false + } + + isUtilLinux := strings.Contains(string(output), "util-linux") + log.Debugf("util-linux login detected: %v", isUtilLinux) + return isUtilLinux +} + // createSuCommand creates a command using su -l -c for privilege switching func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { suPath, err := exec.LookPath("su") @@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu return false } - logger.Infof("starting interactive shell: %s", execCmd.Path) + logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " ")) return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) } diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index 37b3ae0ee..998796871 100644 --- a/client/ssh/server/command_execution_windows.go +++ b/client/ssh/server/command_execution_windows.go @@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool { return false } +// detectUtilLinuxLogin always returns false on Windows +func (s *Server) detectUtilLinuxLogin(context.Context) bool { + return false +} + // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { command := session.RawCommand() diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 44612532b..37763ee0e 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -138,7 +138,8 @@ type Server struct { jwtExtractor *jwt.ClaimsExtractor jwtConfig *JWTConfig - suSupportsPty bool + suSupportsPty bool + loginIsUtilLinux bool } type JWTConfig struct { @@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { } s.suSupportsPty = s.detectSuPtySupport(ctx) + s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx) ln, addrDesc, err := s.createListener(ctx, addr) if err != nil { diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index 06fefabd7..bc1557419 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st switch runtime.GOOS { case "linux": - // Special handling for Arch Linux without /etc/pam.d/remote - if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { - return loginPath, []string{"-f", username, "-p"}, nil - } - return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil + p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String()) + return p, a, nil case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil default: @@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st } } -// fileExists checks if a file exists (helper for login command logic) +// getLinuxLoginCmd returns the login command for Linux systems. +// Handles differences between util-linux and shadow-utils login implementations. +func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) { + // Special handling for Arch Linux without /etc/pam.d/remote + var loginArgs []string + if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { + loginArgs = []string{"-f", username, "-p"} + } else { + loginArgs = []string{"-f", username, "-h", remoteIP, "-p"} + } + + // util-linux login requires setsid -c to create a new session and set the + // controlling terminal. Without this, vhangup() kills the parent process. + // See https://bugs.debian.org/1078023 for details. + // TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec() + // to avoid external setsid dependency. + if !s.loginIsUtilLinux { + return loginPath, loginArgs + } + + setsidPath, err := exec.LookPath("setsid") + if err != nil { + log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err) + return loginPath, loginArgs + } + + args := append([]string{"-w", "-c", loginPath}, loginArgs...) + return setsidPath, args +} + +// fileExists checks if a file exists func (s *Server) fileExists(path string) bool { _, err := os.Stat(path) return err == nil From 3f4f825ec14c45dd149a743eb80e98e7bb5e9ec5 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:42:49 +0100 Subject: [PATCH 094/118] [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) --- client/internal/dnsfwd/forwarder.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index aef16a8cf..6b8042ccb 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns return nil } + // Unmap IPv4-mapped IPv6 addresses that some resolvers may return + for i, ip := range ips { + ips[i] = ip.Unmap() + } + f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) f.cache.set(domain, question.Qtype, ips) From 44851e06fbaaf02f2cde642c40ff6cecea0a3c3f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:26:51 +0100 Subject: [PATCH 095/118] [management] cleanup logs (#4933) --- management/internals/shared/grpc/server.go | 29 +-------- management/internals/shared/grpc/token_mgr.go | 6 +- .../server/http/middleware/auth_middleware.go | 2 +- management/server/peer.go | 5 +- management/server/posture/os_version.go | 4 +- management/server/settings/manager.go | 8 --- management/server/store/sql_store.go | 63 +------------------ 7 files changed, 9 insertions(+), 108 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 62dc215d8..16950db5e 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -134,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser } log.WithContext(ctx).Tracef("GetServerKey request from %s", ip) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start)) - }() // todo introduce something more meaningful with the key expiration/rotation if s.appMetrics != nil { @@ -222,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return err } - log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) - // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) @@ -235,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } }() log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) - log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart)) log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) @@ -352,7 +345,7 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) + log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key) } func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { @@ -561,16 +554,10 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) - defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) } - took := time.Since(reqStart) - if took > 7*time.Second { - log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart)) - } }() if loginReq.GetMeta() == nil { @@ -604,16 +591,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto return nil, mapError(ctx, err) } - log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) - loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) if err != nil { log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) return nil, status.Errorf(codes.Internal, "failed logging in peer") } - log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) - key, err := s.secretsManager.GetWGKey() if err != nil { log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) @@ -730,12 +713,10 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer return status.Errorf(codes.Internal, "error handling request") } - sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) - log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) if err != nil { log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) @@ -750,10 +731,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer // which will be used by our clients to Login func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start)) - }() peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { @@ -813,10 +790,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr // which will be used by our clients to Login func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start)) - }() peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index 0f893ae3a..ccb32202f 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -167,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI relayCancel := make(chan struct{}, 1) m.relayCancelMap[peerID] = relayCancel go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel) - log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID) + log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID) } } @@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc for { select { case <-cancel: - log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) + log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID) return case <-ticker.C: m.pushNewTURNAndRelayTokens(ctx, accountID, peerID) @@ -193,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac for { select { case <-cancel: - log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID) + log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID) return case <-ticker.C: m.pushNewRelayTokens(ctx, accountID, peerID) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index ffd7e0b93..38cf0c290 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -141,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] } if userAuth.AccountId != accountId { - log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) + log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) userAuth.AccountId = accountId } diff --git a/management/server/peer.go b/management/server/peer.go index f2de05f15..49f5bf2a5 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -172,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio } } - log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) + log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected) err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) if err != nil { @@ -783,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } - startTransaction := time.Now() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { @@ -853,8 +852,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } - log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) - if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) if err != nil { diff --git a/management/server/posture/os_version.go b/management/server/posture/os_version.go index 411f4c2c6..2ef97a066 100644 --- a/management/server/posture/os_version.go +++ b/management/server/posture/os_version.go @@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error { func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { if check == nil { - log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS) return false, nil } @@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { if check == nil { - log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS) return false, nil } diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index f16b609f8..2b2896572 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,9 +5,6 @@ package settings import ( "context" "fmt" - "time" - - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { } func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start)) - }() - if userID != activity.SystemInitiator { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 94b7fc1cc..4aa3de499 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -27,7 +27,6 @@ import ( "gorm.io/gorm/logger" nbdns "github.com/netbirdio/netbird/dns" - nbcontext "github.com/netbirdio/netbird/management/server/context" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds()) + log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds()) return err } @@ -583,9 +582,6 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2152,9 +2148,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2171,9 +2164,6 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2229,9 +2219,6 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - var user types.User result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { @@ -2491,9 +2478,6 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2514,9 +2498,6 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - result := s.db.WithContext(ctx).Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ @@ -2537,9 +2518,6 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - var groupID string _ = s.db.WithContext(ctx).Model(types.Group{}). Select("id"). @@ -2569,9 +2547,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer // AddPeerToGroup adds a peer to a group func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - peer := &types.GroupPeer{ AccountID: accountID, GroupID: groupID, @@ -2768,9 +2743,6 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -2897,9 +2869,6 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) @@ -4022,36 +3991,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin return groupPeers, nil } -func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID) - } - - requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID) - } - - accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID) - } - - go func() { - select { - case <-ctx.Done(): - case <-grpcCtx.Done(): - log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err()) - } - }() - return ctx, cancel -} - func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { var info types.PrimaryAccountInfo result := s.db.Model(&types.Account{}). From 94d34dc0c5cc9ea353438adfbe4fd613566be2b2 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:29:15 +0100 Subject: [PATCH 096/118] [management] monitoring updates (#4937) --- management/internals/shared/grpc/server.go | 13 ++++--------- management/server/account.go | 7 +++++++ management/server/telemetry/grpc_metrics.go | 18 ------------------ .../server/telemetry/http_api_metrics.go | 12 ++++++++++++ 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 16950db5e..6029dc8bf 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -171,8 +171,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } s.syncSem.Add(1) - reqStart := time.Now() - ctx := srv.Context() syncReq := &proto.SyncRequest{} @@ -190,7 +188,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) + log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) } if s.blockPeersWithSameConfig { s.syncSem.Add(-1) @@ -263,10 +261,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) - } - unlock() unlock = nil @@ -518,7 +512,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto reqStart := time.Now() realIP := getRealIP(ctx) sRealIP := realIP.String() - log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) loginReq := &proto.LoginRequest{} peerKey, err := s.parseRequest(ctx, req, loginReq) @@ -530,7 +523,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto metahashed := metaHash(peerMeta, sRealIP) if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) + log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) } if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() @@ -554,6 +547,8 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) + defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) diff --git a/management/server/account.go b/management/server/account.go index dac040db0..8a2155cb4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -787,6 +787,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) + if ctx == nil { + ctx = context.Background() + } + + // nolint:staticcheck + ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString) if err != nil { return nil, nil, err diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index d4301802f..4ba1ee14e 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -16,7 +16,6 @@ type GRPCMetrics struct { meter metric.Meter syncRequestsCounter metric.Int64Counter syncRequestsBlockedCounter metric.Int64Counter - syncRequestHighLatencyCounter metric.Int64Counter loginRequestsCounter metric.Int64Counter loginRequestsBlockedCounter metric.Int64Counter loginRequestHighLatencyCounter metric.Int64Counter @@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } - syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", - metric.WithUnit("1"), - metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"), - ) - if err != nil { - return nil, err - } - loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro meter: meter, syncRequestsCounter: syncRequestsCounter, syncRequestsBlockedCounter: syncRequestsBlockedCounter, - syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, loginRequestsCounter: loginRequestsCounter, loginRequestsBlockedCounter: loginRequestsBlockedCounter, loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, @@ -172,14 +162,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration } } -// CountSyncRequestDuration counts the duration of the sync gRPC requests -func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { - grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) - if duration > HighLatencyThreshold { - grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) - } -} - // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error { _, err := grpcMetrics.meter.RegisterCallback( diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index ae27466d9..0b6f8beb6 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { h.ServeHTTP(w, r.WithContext(ctx)) + userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) + if err == nil { + if userAuth.AccountId != "" { + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId) + } + if userAuth.UserId != "" { + //nolint + ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId) + } + } + if w.Status() > 399 { log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } else { From 90e3b8009fb5228a8643143d775eb141f4abba91 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 11 Dec 2025 20:11:12 +0100 Subject: [PATCH 097/118] [management] Fix sync metrics (#4939) --- management/internals/shared/grpc/server.go | 6 ++++++ management/server/telemetry/grpc_metrics.go | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6029dc8bf..462e2e6eb 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -171,6 +171,8 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } s.syncSem.Add(1) + reqStart := time.Now() + ctx := srv.Context() syncReq := &proto.SyncRequest{} @@ -261,6 +263,10 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) + } + unlock() unlock = nil diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index 4ba1ee14e..bd7fbc235 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -162,6 +162,11 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration } } +// CountSyncRequestDuration counts the duration of the sync gRPC requests +func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { + grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) +} + // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error { _, err := grpcMetrics.meter.RegisterCallback( From abcbde26f9294f5bd649386530c21d816bca02be Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:45:47 +0100 Subject: [PATCH 098/118] [management] remove context from store methods (#4940) --- management/server/store/sql_store.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4aa3de499..74b03ce48 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -588,7 +588,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } var user types.User - result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID) + result := tx.Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -2154,7 +2154,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } var accountNetwork types.AccountNetwork - if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -2170,7 +2170,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } var peer nbpeer.Peer - result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey) + result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2220,7 +2220,7 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user types.User - result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) @@ -2484,7 +2484,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } var setupKey types.SetupKey - result := tx.WithContext(ctx). + result := tx. Take(&setupKey, GetKeyQueryCondition(s), key) if result.Error != nil { @@ -2498,7 +2498,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - result := s.db.WithContext(ctx).Model(&types.SetupKey{}). + result := s.db.Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -2519,7 +2519,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { var groupID string - _ = s.db.WithContext(ctx).Model(types.Group{}). + _ = s.db.Model(types.Group{}). Select("id"). Where("account_id = ? AND name = ?", accountID, "All"). Limit(1). @@ -2743,7 +2743,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { + if err := s.db.Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -2869,7 +2869,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -4030,7 +4030,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i Network: &types.Network{Net: ipNet}, } - result := s.db.WithContext(ctx). + result := s.db. Model(&types.Account{}). Where(idQueryCondition, accountID). Updates(&patch) From 932c02eaabbcc3d314013969d85cd1aee353af7b Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 12 Dec 2025 18:49:57 +0300 Subject: [PATCH 099/118] [management] Approve all pending peers when peer approval is disabled (#4806) --- go.mod | 2 +- go.sum | 4 +- management/server/account.go | 29 +++---- management/server/account_test.go | 37 +++++++++ management/server/integrated_validator.go | 2 +- .../integrated_validator/interface.go | 2 +- management/server/store/sql_store.go | 12 +++ management/server/store/sql_store_test.go | 77 +++++++++++++++++++ management/server/store/store.go | 1 + 9 files changed, 148 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index 870118c88..8f4ec530b 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba + github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 96a303f79..f10e1e6da 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/account.go b/management/server/account.go index 8a2155cb4..a9becc4b6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return err } - if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { + if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil { return err } + if oldSettings.Extra != nil && newSettings.Extra != nil && + oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled { + approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to approve pending peers: %w", err) + } + + if approvedCount > 0 { + log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID) + updateAccountPeers = true + } + } + if oldSettings.NetworkRange != newSettings.NetworkRange { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { return err @@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return newSettings, nil } -func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { +func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } - peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") - if err != nil { - return err - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID) + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID) } func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 8569f1b2f..7f125e3a0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } +func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + accountID := account.Id + userID := account.Users[account.CreatedBy].Id + ctx := context.Background() + + newSettings := account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: true, + } + _, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + peer1.Status.RequiresApproval = true + peer2.Status.RequiresApproval = true + peer3.Status.RequiresApproval = false + + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3)) + + newSettings = account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: false, + } + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range accountPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID) + } +} + func TestAccount_GetExpiredPeers(t *testing.T) { type test struct { name string diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index e9a1c8701..69ea668ad 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -127,7 +127,7 @@ type MockIntegratedValidator struct { ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) } -func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error { return nil } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index 26c338cb6..326fbfaf0 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -10,7 +10,7 @@ import ( // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { - ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 74b03ce48..2b8981b97 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -412,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW return nil } +// ApproveAccountPeers marks all peers that currently require approval in the given account as approved. +func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) { + result := s.db.Model(&nbpeer.Peer{}). + Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true). + Update("peer_status_requires_approval", false) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error) + } + + return int(result.RowsAffected), nil +} + // SaveUsers saves the given list of users to the database. func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { if len(users) == 0 { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index d40c4664c..2e2623910 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { }) } } + +func TestSqlStore_ApproveAccountPeers(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + accountID := "test-account" + ctx := context.Background() + + account := newAccountWithId(ctx, accountID, "testuser", "example.com") + err := store.SaveAccount(ctx, account) + require.NoError(t, err) + + peers := []*nbpeer.Peer{ + { + ID: "peer1", + AccountID: accountID, + DNSLabel: "peer1.netbird.cloud", + Key: "peer1-key", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer2", + AccountID: accountID, + DNSLabel: "peer2.netbird.cloud", + Key: "peer2-key", + IP: net.ParseIP("100.64.0.2"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer3", + AccountID: accountID, + DNSLabel: "peer3.netbird.cloud", + Key: "peer3-key", + IP: net.ParseIP("100.64.0.3"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: false, + LastSeen: time.Now().UTC(), + }, + }, + } + + for _, peer := range peers { + err = store.AddPeerToAccount(ctx, peer) + require.NoError(t, err) + } + + t.Run("approve all pending peers", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 2, count) + + allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range allPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID) + } + }) + + t.Run("no peers to approve", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + + t.Run("non-existent account", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, "non-existent") + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 007e2b739..0ec7949f9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -143,6 +143,7 @@ type Store interface { SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error + ApproveAccountPeers(ctx context.Context, accountID string) (int, error) DeletePeer(ctx context.Context, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) From 08f31fbcb3198f309fb52a70a48d99d47ca35b83 Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Fri, 12 Dec 2025 14:29:58 -0300 Subject: [PATCH 100/118] [iOS] Add force relay connection on iOS (#4928) * [ios] Add a bogus test to check iOS behavior when setting environment variables * [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables" This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b. * [ios] Add EnvList struct to export and import environment variables * [ios] Add envList parameter to the iOS Client Run method * [ios] Add some debug logging to exportEnvVarList * Add "//go:build ios" to client/ios/NetBirdSDK files --- client/ios/NetBirdSDK/client.go | 22 ++++++++++++++- client/ios/NetBirdSDK/env_list.go | 34 +++++++++++++++++++++++ client/ios/NetBirdSDK/gomobile.go | 2 ++ client/ios/NetBirdSDK/logger.go | 2 ++ client/ios/NetBirdSDK/login.go | 2 ++ client/ios/NetBirdSDK/peer_notifier.go | 2 ++ client/ios/NetBirdSDK/preferences.go | 2 ++ client/ios/NetBirdSDK/preferences_test.go | 2 ++ client/ios/NetBirdSDK/routes.go | 2 ++ 9 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 client/ios/NetBirdSDK/env_list.go diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 6d969bb12..463c93d57 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -1,9 +1,12 @@ +//go:build ios + package NetBirdSDK import ( "context" "fmt" "net/netip" + "os" "sort" "strings" "sync" @@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s } // Run start the internal client. It is a blocker function -func (c *Client) Run(fd int32, interfaceName string) error { +func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { + exportEnvList(envList) log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ @@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID { } return netIDs } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k)) + log.Debugf("Setting env variable %s: %s", k, v) + + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } else { + log.Debugf("Env variable %s was set successfully", k) + } + } +} diff --git a/client/ios/NetBirdSDK/env_list.go b/client/ios/NetBirdSDK/env_list.go new file mode 100644 index 000000000..4800803d7 --- /dev/null +++ b/client/ios/NetBirdSDK/env_list.go @@ -0,0 +1,34 @@ +//go:build ios + +package NetBirdSDK + +import "github.com/netbirdio/netbird/client/internal/peer" + +// EnvList is an exported struct to be bound by gomobile +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair +func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} + +// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client +func GetEnvKeyNBForceRelay() string { + return peer.EnvKeyNBForceRelay +} diff --git a/client/ios/NetBirdSDK/gomobile.go b/client/ios/NetBirdSDK/gomobile.go index 9eadd6a7f..79bf0c2ac 100644 --- a/client/ios/NetBirdSDK/gomobile.go +++ b/client/ios/NetBirdSDK/gomobile.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import _ "golang.org/x/mobile/bind" diff --git a/client/ios/NetBirdSDK/logger.go b/client/ios/NetBirdSDK/logger.go index f1ad1b9f6..531d0ba89 100644 --- a/client/ios/NetBirdSDK/logger.go +++ b/client/ios/NetBirdSDK/logger.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 570c44f80..1c2b38a61 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go index 16c5039eb..9b00568be 100644 --- a/client/ios/NetBirdSDK/peer_notifier.go +++ b/client/ios/NetBirdSDK/peer_notifier.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK // PeerInfo describe information about the peers. It designed for the UI usage diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index 5e7050465..39ae06538 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index 780443a7b..5f75e7c9a 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go index 30d0d0d0a..7b84d6e1c 100644 --- a/client/ios/NetBirdSDK/routes.go +++ b/client/ios/NetBirdSDK/routes.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK // RoutesSelectionInfoCollection made for Java layer to get non default types as collection From 5748bdd64edf0a0e77dab2311ceb700d049730c8 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 15 Dec 2025 10:28:25 +0100 Subject: [PATCH 101/118] Add health-check agent recognition to avoid error logs (#4917) Health-check connections now send a properly formatted auth message with a well-known peer ID instead of immediately closing. The server recognizes this peer ID and handles the connection gracefully with a debug log instead of error logs. --- relay/healthcheck/peerid/peerid.go | 31 ++++++++++++++++++++++++++++++ relay/healthcheck/ws.go | 15 ++++++++++++++- relay/server/handshake.go | 4 ++-- relay/server/relay.go | 7 ++++++- 4 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 relay/healthcheck/peerid/peerid.go diff --git a/relay/healthcheck/peerid/peerid.go b/relay/healthcheck/peerid/peerid.go new file mode 100644 index 000000000..cd8696817 --- /dev/null +++ b/relay/healthcheck/peerid/peerid.go @@ -0,0 +1,31 @@ +package peerid + +import ( + "crypto/sha256" + + v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2" + "github.com/netbirdio/netbird/shared/relay/messages" +) + +var ( + // HealthCheckPeerID is the hashed peer ID for health check connections + HealthCheckPeerID = messages.HashID("healthcheck-agent") + + // DummyAuthToken is a structurally valid auth token for health check. + // The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload). + DummyAuthToken = createDummyToken() +) + +func createDummyToken() []byte { + token := v2.Token{ + AuthAlgo: v2.AuthAlgoHMACSHA256, + Signature: make([]byte, sha256.Size), + Payload: []byte("healthcheck"), + } + return token.Marshal() +} + +// IsHealthCheck checks if the given peer ID is the health check agent +func IsHealthCheck(peerID *messages.PeerID) bool { + return peerID != nil && *peerID == HealthCheckPeerID +} diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index db61ed802..9267096f5 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -7,8 +7,10 @@ import ( "github.com/coder/websocket" + "github.com/netbirdio/netbird/relay/healthcheck/peerid" "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay" + "github.com/netbirdio/netbird/shared/relay/messages" ) func dialWS(ctx context.Context, address url.URL) error { @@ -30,7 +32,18 @@ func dialWS(ctx context.Context, address url.URL) error { if err != nil { return fmt.Errorf("failed to connect to websocket: %w", err) } + defer func() { + _ = conn.CloseNow() + }() + + authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken) + if err != nil { + return fmt.Errorf("failed to marshal auth message: %w", err) + } + + if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil { + return fmt.Errorf("failed to write auth message: %w", err) + } - _ = conn.Close(websocket.StatusNormalClosure, "availability check complete") return nil } diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 922369798..8c3ee1899 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) { return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) } if err != nil { - return nil, err + return peerID, err } h.peerID = peerID return peerID, nil @@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) { } if err := h.validator.Validate(authPayload); err != nil { - return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) + return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) } return rawPeerID, nil diff --git a/relay/server/relay.go b/relay/server/relay.go index c1cfa13fd..bb355f58f 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" + "github.com/netbirdio/netbird/relay/healthcheck/peerid" //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" @@ -123,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) { } peerID, err := h.handshakeReceive() if err != nil { - log.Errorf("failed to handshake: %s", err) + if peerid.IsHealthCheck(peerID) { + log.Debugf("health check connection from %s", conn.RemoteAddr()) + } else { + log.Errorf("failed to handshake: %s", err) + } if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } From 447cd287f5ef17c12242125ed154ddfa1dd82acf Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 15 Dec 2025 10:34:48 +0100 Subject: [PATCH 102/118] [ci] Add local lint setup with pre-push hook to catch issues early (#4925) * Add local lint setup with pre-push hook to catch issues early Developers can now catch lint issues before pushing, reducing CI failures and iteration time. The setup uses golangci-lint locally with the same configuration as CI. Setup: - Run `make setup-hooks` once after cloning - Pre-push hook automatically lints changed files (~90s) - Use `make lint` to manually check changed files - Use `make lint-all` to run full CI-equivalent lint The Makefile auto-installs golangci-lint to ./bin/ using go install to match the Go version in go.mod, avoiding version compatibility issues. --------- Co-authored-by: mlsmaycon --- .githooks/pre-push | 11 +++++++++++ CONTRIBUTING.md | 8 ++++++++ Makefile | 27 +++++++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100755 .githooks/pre-push create mode 100644 Makefile diff --git a/.githooks/pre-push b/.githooks/pre-push new file mode 100755 index 000000000..31898182e --- /dev/null +++ b/.githooks/pre-push @@ -0,0 +1,11 @@ +#!/bin/bash + +echo "Running pre-push hook..." +if ! make lint; then + echo "" + echo "Hint: To push without verification, run:" + echo " git push --no-verify" + exit 1 +fi + +echo "All checks passed!" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c82cfc763..efc7d9460 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -136,6 +136,14 @@ checked out and set up: go mod tidy ``` +6. Configure Git hooks for automatic linting: + + ```bash + make setup-hooks + ``` + + This will configure Git to run linting automatically before each push, helping catch issues early. + ### Dev Container Support If you prefer using a dev container for development, NetBird now includes support for dev containers. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..43379e115 --- /dev/null +++ b/Makefile @@ -0,0 +1,27 @@ +.PHONY: lint lint-all lint-install setup-hooks +GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint + +# Install golangci-lint locally if needed +$(GOLANGCI_LINT): + @echo "Installing golangci-lint..." + @mkdir -p ./bin + @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + +# Lint only changed files (fast, for pre-push) +lint: $(GOLANGCI_LINT) + @echo "Running lint on changed files..." + @$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m + +# Lint entire codebase (slow, matches CI) +lint-all: $(GOLANGCI_LINT) + @echo "Running lint on all files..." + @$(GOLANGCI_LINT) run --timeout=12m + +# Just install the linter +lint-install: $(GOLANGCI_LINT) + +# Setup git hooks for all developers +setup-hooks: + @git config core.hooksPath .githooks + @chmod +x .githooks/pre-push + @echo "✅ Git hooks configured! Pre-push will now run 'make lint'" From c29bb1a289d6e8688e3d994ddfc8d40f25a8babd Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:02:37 +0100 Subject: [PATCH 103/118] [management] use xid as request id for logging (#4955) --- formatter/hook/hook.go | 35 ++----------------- management/internals/server/boot.go | 6 ++-- .../server/telemetry/http_api_metrics.go | 4 +-- 3 files changed, 7 insertions(+), 38 deletions(-) diff --git a/formatter/hook/hook.go b/formatter/hook/hook.go index c0d8c4eba..f0ee509f8 100644 --- a/formatter/hook/hook.go +++ b/formatter/hook/hook.go @@ -60,14 +60,7 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error { entry.Data["context"] = source - switch source { - case HTTPSource: - addHTTPFields(entry) - case GRPCSource: - addGRPCFields(entry) - case SystemSource: - addSystemFields(entry) - } + addFields(entry) return nil } @@ -99,7 +92,7 @@ func (hook ContextHook) parseSrc(filePath string) string { return fmt.Sprintf("%s/%s", pkg, file) } -func addHTTPFields(entry *logrus.Entry) { +func addFields(entry *logrus.Entry) { if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { entry.Data[context.RequestIDKey] = ctxReqID } @@ -109,30 +102,6 @@ func addHTTPFields(entry *logrus.Entry) { if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { entry.Data[context.UserIDKey] = ctxInitiatorID } -} - -func addGRPCFields(entry *logrus.Entry) { - if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { - entry.Data[context.RequestIDKey] = ctxReqID - } - if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { - entry.Data[context.AccountIDKey] = ctxAccountID - } - if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { - entry.Data[context.PeerIDKey] = ctxDeviceID - } -} - -func addSystemFields(entry *logrus.Entry) { - if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { - entry.Data[context.RequestIDKey] = ctxReqID - } - if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { - entry.Data[context.UserIDKey] = ctxInitiatorID - } - if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { - entry.Data[context.AccountIDKey] = ctxAccountID - } if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { entry.Data[context.PeerIDKey] = ctxDeviceID } diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 37788e80e..57b3fac78 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -10,9 +10,9 @@ import ( "slices" "time" - "github.com/google/uuid" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -180,7 +180,7 @@ func unaryInterceptor( info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { - reqID := uuid.New().String() + reqID := xid.New().String() //nolint ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource) //nolint @@ -194,7 +194,7 @@ func streamInterceptor( info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { - reqID := uuid.New().String() + reqID := xid.New().String() wrapped := grpcMiddleware.WrapServerStream(ss) //nolint ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource) diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index 0b6f8beb6..c50ed1e51 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -7,8 +7,8 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/gorilla/mux" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -169,7 +169,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { //nolint ctx := context.WithValue(r.Context(), hook.ExecutionContextKey, hook.HTTPSource) - reqID := uuid.New().String() + reqID := xid.New().String() //nolint ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) From a9c28ef723c790792124209b02a5131087dda2b4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 17 Dec 2025 13:49:02 +0100 Subject: [PATCH 104/118] Add stack trace for bundle (#4957) --- client/internal/debug/debug.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 58977b884..3c201ecfc 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -56,6 +56,7 @@ block.prof: Block profiling information. heap.prof: Heap profiling information (snapshot of memory allocations). allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. +stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. Anonymization Process @@ -109,6 +110,9 @@ go tool pprof -http=:8088 heap.prof This will open a web browser tab with the profiling information. +Stack Trace +The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created. + Routes The routes.txt file contains detailed routing table information in a tabular format: @@ -327,6 +331,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add profiles to debug bundle: %v", err) } + if err := g.addStackTrace(); err != nil { + log.Errorf("failed to add stack trace to debug bundle: %v", err) + } + if err := g.addSyncResponse(); err != nil { return fmt.Errorf("add sync response: %w", err) } @@ -522,6 +530,18 @@ func (g *BundleGenerator) addProf() (err error) { return nil } +func (g *BundleGenerator) addStackTrace() error { + buf := make([]byte, 5242880) // 5 MB buffer + n := runtime.Stack(buf, true) + + stackTrace := bytes.NewReader(buf[:n]) + if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil { + return fmt.Errorf("add stack trace file to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addInterfaces() error { interfaces, err := net.Interfaces() if err != nil { From 537151e0f365bfd9c94e8fa5b9c88b7cb1b78a02 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 17 Dec 2025 13:55:33 +0100 Subject: [PATCH 105/118] Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) --- client/internal/peer/endpoint.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go index 39cb95591..52d66159c 100644 --- a/client/internal/peer/endpoint.go +++ b/client/internal/peer/endpoint.go @@ -20,7 +20,7 @@ type EndpointUpdater struct { wgConfig WgConfig initiator bool - // mu protects updateWireGuardPeer and cancelFunc + // mu protects cancelFunc mu sync.Mutex cancelFunc func() updateWg sync.WaitGroup @@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U case <-ctx.Done(): return case <-t.C: - e.mu.Lock() if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) } - e.mu.Unlock() } } From 011cc81678833f49c24cab8abff960ed7635c9b6 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 19 Dec 2025 19:57:39 +0100 Subject: [PATCH 106/118] [client, management] auto-update (#4732) --- client/android/client.go | 4 +- client/cmd/root.go | 3 + client/cmd/signer/artifactkey.go | 176 +++ client/cmd/signer/artifactsign.go | 276 +++++ client/cmd/signer/main.go | 21 + client/cmd/signer/revocation.go | 220 ++++ client/cmd/signer/rootkey.go | 74 ++ client/cmd/up.go | 2 +- client/cmd/update.go | 13 + client/cmd/update_supported.go | 75 ++ client/embed/embed.go | 2 +- client/internal/connect.go | 64 +- client/internal/debug/debug.go | 28 + client/internal/engine.go | 87 +- client/internal/engine_test.go | 31 +- client/internal/profilemanager/config.go | 22 +- client/internal/updatemanager/doc.go | 35 + .../updatemanager/downloader/downloader.go | 138 +++ .../downloader/downloader_test.go | 199 +++ .../installer/binary_nowindows.go | 7 + .../updatemanager/installer/binary_windows.go | 11 + .../internal/updatemanager/installer/doc.go | 111 ++ .../updatemanager/installer/installer.go | 50 + .../installer/installer_common.go | 293 +++++ .../installer/installer_log_darwin.go | 11 + .../installer/installer_log_windows.go | 12 + .../installer/installer_run_darwin.go | 238 ++++ .../installer/installer_run_windows.go | 213 ++++ .../internal/updatemanager/installer/log.go | 5 + .../installer/procattr_darwin.go | 15 + .../installer/procattr_windows.go | 14 + .../updatemanager/installer/repourl_dev.go | 7 + .../updatemanager/installer/repourl_prod.go | 7 + .../updatemanager/installer/result.go | 230 ++++ .../internal/updatemanager/installer/types.go | 14 + .../updatemanager/installer/types_darwin.go | 22 + .../updatemanager/installer/types_windows.go | 51 + client/internal/updatemanager/manager.go | 374 ++++++ client/internal/updatemanager/manager_test.go | 214 ++++ .../updatemanager/manager_unsupported.go | 39 + .../updatemanager/reposign/artifact.go | 302 +++++ .../updatemanager/reposign/artifact_test.go | 1080 +++++++++++++++++ .../updatemanager/reposign/certs/root-pub.pem | 6 + .../reposign/certsdev/root-pub.pem | 6 + client/internal/updatemanager/reposign/doc.go | 174 +++ .../updatemanager/reposign/embed_dev.go | 10 + .../updatemanager/reposign/embed_prod.go | 10 + client/internal/updatemanager/reposign/key.go | 171 +++ .../updatemanager/reposign/key_test.go | 636 ++++++++++ .../updatemanager/reposign/revocation.go | 229 ++++ .../updatemanager/reposign/revocation_test.go | 860 +++++++++++++ .../internal/updatemanager/reposign/root.go | 120 ++ .../updatemanager/reposign/root_test.go | 476 ++++++++ .../updatemanager/reposign/signature.go | 24 + .../updatemanager/reposign/signature_test.go | 277 +++++ .../internal/updatemanager/reposign/verify.go | 187 +++ .../updatemanager/reposign/verify_test.go | 528 ++++++++ client/internal/updatemanager/update.go | 11 + client/ios/NetBirdSDK/client.go | 2 +- client/netbird.wxs | 2 +- client/proto/daemon.pb.go | 227 +++- client/proto/daemon.proto | 11 + client/proto/daemon_grpc.pb.go | 36 + client/proto/generate.sh | 2 +- client/server/server.go | 20 +- client/server/server_test.go | 2 +- client/server/updateresult.go | 30 + client/ui/client_ui.go | 98 +- client/ui/event_handler.go | 26 +- client/ui/profile.go | 6 +- client/ui/quickactions.go | 2 +- client/ui/update.go | 140 +++ client/ui/update_notwindows.go | 7 + client/ui/update_windows.go | 44 + management/internals/server/server.go | 2 +- .../internals/shared/grpc/conversion.go | 11 +- management/server/account.go | 12 +- management/server/activity/codes.go | 11 +- .../handlers/accounts/accounts_handler.go | 106 +- .../accounts/accounts_handler_test.go | 28 + management/server/types/settings.go | 4 + shared/management/http/api/openapi.yml | 4 + shared/management/http/api/types.gen.go | 3 + shared/management/proto/management.pb.go | 1027 +++++++++------- shared/management/proto/management.proto | 12 + version/update.go | 40 +- version/update_test.go | 6 +- 87 files changed, 9720 insertions(+), 716 deletions(-) create mode 100644 client/cmd/signer/artifactkey.go create mode 100644 client/cmd/signer/artifactsign.go create mode 100644 client/cmd/signer/main.go create mode 100644 client/cmd/signer/revocation.go create mode 100644 client/cmd/signer/rootkey.go create mode 100644 client/cmd/update.go create mode 100644 client/cmd/update_supported.go create mode 100644 client/internal/updatemanager/doc.go create mode 100644 client/internal/updatemanager/downloader/downloader.go create mode 100644 client/internal/updatemanager/downloader/downloader_test.go create mode 100644 client/internal/updatemanager/installer/binary_nowindows.go create mode 100644 client/internal/updatemanager/installer/binary_windows.go create mode 100644 client/internal/updatemanager/installer/doc.go create mode 100644 client/internal/updatemanager/installer/installer.go create mode 100644 client/internal/updatemanager/installer/installer_common.go create mode 100644 client/internal/updatemanager/installer/installer_log_darwin.go create mode 100644 client/internal/updatemanager/installer/installer_log_windows.go create mode 100644 client/internal/updatemanager/installer/installer_run_darwin.go create mode 100644 client/internal/updatemanager/installer/installer_run_windows.go create mode 100644 client/internal/updatemanager/installer/log.go create mode 100644 client/internal/updatemanager/installer/procattr_darwin.go create mode 100644 client/internal/updatemanager/installer/procattr_windows.go create mode 100644 client/internal/updatemanager/installer/repourl_dev.go create mode 100644 client/internal/updatemanager/installer/repourl_prod.go create mode 100644 client/internal/updatemanager/installer/result.go create mode 100644 client/internal/updatemanager/installer/types.go create mode 100644 client/internal/updatemanager/installer/types_darwin.go create mode 100644 client/internal/updatemanager/installer/types_windows.go create mode 100644 client/internal/updatemanager/manager.go create mode 100644 client/internal/updatemanager/manager_test.go create mode 100644 client/internal/updatemanager/manager_unsupported.go create mode 100644 client/internal/updatemanager/reposign/artifact.go create mode 100644 client/internal/updatemanager/reposign/artifact_test.go create mode 100644 client/internal/updatemanager/reposign/certs/root-pub.pem create mode 100644 client/internal/updatemanager/reposign/certsdev/root-pub.pem create mode 100644 client/internal/updatemanager/reposign/doc.go create mode 100644 client/internal/updatemanager/reposign/embed_dev.go create mode 100644 client/internal/updatemanager/reposign/embed_prod.go create mode 100644 client/internal/updatemanager/reposign/key.go create mode 100644 client/internal/updatemanager/reposign/key_test.go create mode 100644 client/internal/updatemanager/reposign/revocation.go create mode 100644 client/internal/updatemanager/reposign/revocation_test.go create mode 100644 client/internal/updatemanager/reposign/root.go create mode 100644 client/internal/updatemanager/reposign/root_test.go create mode 100644 client/internal/updatemanager/reposign/signature.go create mode 100644 client/internal/updatemanager/reposign/signature_test.go create mode 100644 client/internal/updatemanager/reposign/verify.go create mode 100644 client/internal/updatemanager/reposign/verify_test.go create mode 100644 client/internal/updatemanager/update.go create mode 100644 client/server/updateresult.go create mode 100644 client/ui/update.go create mode 100644 client/ui/update_notwindows.go create mode 100644 client/ui/update_windows.go diff --git a/client/android/client.go b/client/android/client.go index 0d5474c4b..b16a23d64 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -122,7 +122,7 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } @@ -149,7 +149,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } diff --git a/client/cmd/root.go b/client/cmd/root.go index 9f2eb109c..30120c196 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -85,6 +85,9 @@ var ( // Execute executes the root command. func Execute() error { + if isUpdateBinary() { + return updateCmd.Execute() + } return rootCmd.Execute() } diff --git a/client/cmd/signer/artifactkey.go b/client/cmd/signer/artifactkey.go new file mode 100644 index 000000000..5e656650b --- /dev/null +++ b/client/cmd/signer/artifactkey.go @@ -0,0 +1,176 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +var ( + bundlePubKeysRootPrivKeyFile string + bundlePubKeysPubKeyFiles []string + bundlePubKeysFile string + + createArtifactKeyRootPrivKeyFile string + createArtifactKeyPrivKeyFile string + createArtifactKeyPubKeyFile string + createArtifactKeyExpiration time.Duration +) + +var createArtifactKeyCmd = &cobra.Command{ + Use: "create-artifact-key", + Short: "Create a new artifact signing key", + Long: `Generate a new artifact signing key pair signed by the root private key. +The artifact key will be used to sign software artifacts/updates.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if createArtifactKeyExpiration <= 0 { + return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)") + } + + if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil { + return fmt.Errorf("failed to create artifact key: %w", err) + } + return nil + }, +} + +var bundlePubKeysCmd = &cobra.Command{ + Use: "bundle-pub-keys", + Short: "Bundle multiple artifact public keys into a signed package", + Long: `Bundle one or more artifact public keys into a signed package using the root private key. +This command is typically used to distribute or authorize a set of valid artifact signing keys.`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(bundlePubKeysPubKeyFiles) == 0 { + return fmt.Errorf("at least one --artifact-pub-key-file must be provided") + } + + if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil { + return fmt.Errorf("failed to bundle public keys: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(createArtifactKeyCmd) + + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key") + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved") + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved") + createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)") + + if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil { + panic(fmt.Errorf("mark root-private-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil { + panic(fmt.Errorf("mark expiration as required: %w", err)) + } + + rootCmd.AddCommand(bundlePubKeysCmd) + + bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle") + bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)") + bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved") + + if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil { + panic(fmt.Errorf("mark root-private-key-file as required: %w", err)) + } + if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err)) + } + if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil { + panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err)) + } +} + +func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error { + cmd.Println("Creating new artifact signing key...") + + privKeyPEM, err := os.ReadFile(rootPrivKeyFile) + if err != nil { + return fmt.Errorf("read root private key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration) + if err != nil { + return fmt.Errorf("generate artifact key: %w", err) + } + + if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil { + return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err) + } + + if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil { + return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err) + } + + signatureFile := artifactPubKeyFile + ".sig" + if err := os.WriteFile(signatureFile, signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", signatureFile, err) + } + + cmd.Printf("✅ Artifact key created successfully.\n") + cmd.Printf("%s\n", artifactKey.String()) + return nil +} + +func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error { + cmd.Println("📦 Bundling public keys into signed package...") + + privKeyPEM, err := os.ReadFile(rootPrivKeyFile) + if err != nil { + return fmt.Errorf("read root private key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles)) + for _, pubFile := range artifactPubKeyFiles { + pubPem, err := os.ReadFile(pubFile) + if err != nil { + return fmt.Errorf("read public key file: %w", err) + } + + pk, err := reposign.ParseArtifactPubKey(pubPem) + if err != nil { + return fmt.Errorf("failed to parse artifact key: %w", err) + } + publicKeys = append(publicKeys, pk) + } + + parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys) + if err != nil { + return fmt.Errorf("bundle artifact keys: %w", err) + } + + if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil { + return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err) + } + + signatureFile := bundlePubKeysFile + ".sig" + if err := os.WriteFile(signatureFile, signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", signatureFile, err) + } + + cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles)) + return nil +} diff --git a/client/cmd/signer/artifactsign.go b/client/cmd/signer/artifactsign.go new file mode 100644 index 000000000..881be9367 --- /dev/null +++ b/client/cmd/signer/artifactsign.go @@ -0,0 +1,276 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +const ( + envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY" +) + +var ( + signArtifactPrivKeyFile string + signArtifactArtifactFile string + + verifyArtifactPubKeyFile string + verifyArtifactFile string + verifyArtifactSignatureFile string + + verifyArtifactKeyPubKeyFile string + verifyArtifactKeyRootPubKeyFile string + verifyArtifactKeySignatureFile string + verifyArtifactKeyRevocationFile string +) + +var signArtifactCmd = &cobra.Command{ + Use: "sign-artifact", + Short: "Sign an artifact using an artifact private key", + Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key. +This command produces a detached signature that can be verified using the corresponding artifact public key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil { + return fmt.Errorf("failed to sign artifact: %w", err) + } + return nil + }, +} + +var verifyArtifactCmd = &cobra.Command{ + Use: "verify-artifact", + Short: "Verify an artifact signature using an artifact public key", + Long: `Verify a software artifact signature using the artifact's public key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil { + return fmt.Errorf("failed to verify artifact: %w", err) + } + return nil + }, +} + +var verifyArtifactKeyCmd = &cobra.Command{ + Use: "verify-artifact-key", + Short: "Verify an artifact public key was signed by a root key", + Long: `Verify that an artifact public key (or bundle) was properly signed by a root key. +This validates the chain of trust from the root key to the artifact key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil { + return fmt.Errorf("failed to verify artifact key: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(signArtifactCmd) + rootCmd.AddCommand(verifyArtifactCmd) + rootCmd.AddCommand(verifyArtifactKeyCmd) + + signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey)) + signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed") + + // artifact-file is required, but artifact-key-file can come from env var + if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil { + panic(fmt.Errorf("mark artifact-file as required: %w", err)) + } + + verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file") + verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified") + verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file") + + if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err)) + } + if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil { + panic(fmt.Errorf("mark artifact-file as required: %w", err)) + } + if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil { + panic(fmt.Errorf("mark signature-file as required: %w", err)) + } + + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)") + + if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-key-file as required: %w", err)) + } + if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil { + panic(fmt.Errorf("mark root-key-file as required: %w", err)) + } + if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil { + panic(fmt.Errorf("mark signature-file as required: %w", err)) + } +} + +func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error { + cmd.Println("🖋️ Signing artifact...") + + // Load private key from env var or file + var privKeyPEM []byte + var err error + + if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" { + // Use key from environment variable + privKeyPEM = []byte(envKey) + } else if privKeyFile != "" { + // Fall back to file + privKeyPEM, err = os.ReadFile(privKeyFile) + if err != nil { + return fmt.Errorf("read private key file: %w", err) + } + } else { + return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey) + } + + privateKey, err := reposign.ParseArtifactKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse artifact private key: %w", err) + } + + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + return fmt.Errorf("read artifact file: %w", err) + } + + signature, err := reposign.SignData(privateKey, artifactData) + if err != nil { + return fmt.Errorf("sign artifact: %w", err) + } + + sigFile := artifactFile + ".sig" + if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", sigFile, err) + } + + cmd.Printf("✅ Artifact signed successfully.\n") + cmd.Printf("Signature file: %s\n", sigFile) + return nil +} + +func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error { + cmd.Println("🔍 Verifying artifact...") + + // Read artifact public key + pubKeyPEM, err := os.ReadFile(pubKeyFile) + if err != nil { + return fmt.Errorf("read public key file: %w", err) + } + + publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse artifact public key: %w", err) + } + + // Read artifact data + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + return fmt.Errorf("read artifact file: %w", err) + } + + // Read signature + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("read signature file: %w", err) + } + + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Validate artifact + if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil { + return fmt.Errorf("artifact verification failed: %w", err) + } + + cmd.Println("✅ Artifact signature is valid") + cmd.Printf("Artifact: %s\n", artifactFile) + cmd.Printf("Signed by key: %s\n", signature.KeyID) + cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST")) + return nil +} + +func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error { + cmd.Println("🔍 Verifying artifact key...") + + // Read artifact key data + artifactKeyData, err := os.ReadFile(artifactKeyFile) + if err != nil { + return fmt.Errorf("read artifact key file: %w", err) + } + + // Read root public key(s) + rootKeyData, err := os.ReadFile(rootKeyFile) + if err != nil { + return fmt.Errorf("read root key file: %w", err) + } + + rootPublicKeys, err := parseRootPublicKeys(rootKeyData) + if err != nil { + return fmt.Errorf("failed to parse root public key(s): %w", err) + } + + // Read signature + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("read signature file: %w", err) + } + + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Read optional revocation list + var revocationList *reposign.RevocationList + if revocationFile != "" { + revData, err := os.ReadFile(revocationFile) + if err != nil { + return fmt.Errorf("read revocation file: %w", err) + } + + revocationList, err = reposign.ParseRevocationList(revData) + if err != nil { + return fmt.Errorf("failed to parse revocation list: %w", err) + } + } + + // Validate artifact key(s) + validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList) + if err != nil { + return fmt.Errorf("artifact key verification failed: %w", err) + } + + cmd.Println("✅ Artifact key(s) verified successfully") + cmd.Printf("Signed by root key: %s\n", signature.KeyID) + cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST")) + cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys)) + for i, key := range validKeys { + cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID) + cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST")) + if !key.Metadata.ExpiresAt.IsZero() { + cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST")) + } else { + cmd.Printf(" Expires: Never\n") + } + } + return nil +} + +// parseRootPublicKeys parses a root public key from PEM data +func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) { + key, err := reposign.ParseRootPublicKey(data) + if err != nil { + return nil, err + } + return []reposign.PublicKey{key}, nil +} diff --git a/client/cmd/signer/main.go b/client/cmd/signer/main.go new file mode 100644 index 000000000..407093d07 --- /dev/null +++ b/client/cmd/signer/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "os" + + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "signer", + Short: "A CLI tool for managing cryptographic keys and artifacts", + Long: `signer is a command-line tool that helps you manage +root keys, artifact keys, and revocation lists securely.`, +} + +func main() { + if err := rootCmd.Execute(); err != nil { + rootCmd.Println(err) + os.Exit(1) + } +} diff --git a/client/cmd/signer/revocation.go b/client/cmd/signer/revocation.go new file mode 100644 index 000000000..1d84b65c3 --- /dev/null +++ b/client/cmd/signer/revocation.go @@ -0,0 +1,220 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +const ( + defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year +) + +var ( + keyID string + revocationListFile string + privateRootKeyFile string + publicRootKeyFile string + signatureFile string + expirationDuration time.Duration +) + +var createRevocationListCmd = &cobra.Command{ + Use: "create-revocation-list", + Short: "Create a new revocation list signed by the private root key", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile) + }, +} + +var extendRevocationListCmd = &cobra.Command{ + Use: "extend-revocation-list", + Short: "Extend an existing revocation list with a given key ID", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile) + }, +} + +var verifyRevocationListCmd = &cobra.Command{ + Use: "verify-revocation-list", + Short: "Verify a revocation list signature using the public root key", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile) + }, +} + +func init() { + rootCmd.AddCommand(createRevocationListCmd) + rootCmd.AddCommand(extendRevocationListCmd) + rootCmd.AddCommand(verifyRevocationListCmd) + + createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file") + createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file") + createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)") + if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil { + panic(err) + } + + extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for") + extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file") + extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file") + extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)") + if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil { + panic(err) + } + if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil { + panic(err) + } + + verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file") + verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file") + verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file") + if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil { + panic(err) + } + if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil { + panic(err) + } +} + +func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error { + privKeyPEM, err := os.ReadFile(privateRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read private root key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration) + if err != nil { + return fmt.Errorf("failed to create revocation list: %w", err) + } + + if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil { + return fmt.Errorf("failed to write output files: %w", err) + } + + cmd.Println("✅ Revocation list created successfully") + return nil +} + +func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error { + privKeyPEM, err := os.ReadFile(privateRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read private root key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + rlBytes, err := os.ReadFile(revocationListFile) + if err != nil { + return fmt.Errorf("failed to read revocation list file: %w", err) + } + + rl, err := reposign.ParseRevocationList(rlBytes) + if err != nil { + return fmt.Errorf("failed to parse revocation list: %w", err) + } + + kid, err := reposign.ParseKeyID(keyID) + if err != nil { + return fmt.Errorf("invalid key ID: %w", err) + } + + newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration) + if err != nil { + return fmt.Errorf("failed to extend revocation list: %w", err) + } + + if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil { + return fmt.Errorf("failed to write output files: %w", err) + } + + cmd.Println("✅ Revocation list extended successfully") + return nil +} + +func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error { + // Read revocation list file + rlBytes, err := os.ReadFile(revocationListFile) + if err != nil { + return fmt.Errorf("failed to read revocation list file: %w", err) + } + + // Read signature file + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("failed to read signature file: %w", err) + } + + // Read public root key file + pubKeyPEM, err := os.ReadFile(publicRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read public root key file: %w", err) + } + + // Parse public root key + publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse public root key: %w", err) + } + + // Parse signature + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Validate revocation list + rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature) + if err != nil { + return fmt.Errorf("failed to validate revocation list: %w", err) + } + + // Display results + cmd.Println("✅ Revocation list signature is valid") + cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339)) + cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339)) + cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked)) + + if len(rl.Revoked) > 0 { + cmd.Println("\nRevoked Keys:") + for keyID, revokedTime := range rl.Revoked { + cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339)) + } + } + + return nil +} + +func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error { + if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil { + return fmt.Errorf("failed to write revocation list file: %w", err) + } + if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil { + return fmt.Errorf("failed to write signature file: %w", err) + } + return nil +} diff --git a/client/cmd/signer/rootkey.go b/client/cmd/signer/rootkey.go new file mode 100644 index 000000000..78ac36b41 --- /dev/null +++ b/client/cmd/signer/rootkey.go @@ -0,0 +1,74 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +var ( + privKeyFile string + pubKeyFile string + rootExpiration time.Duration +) + +var createRootKeyCmd = &cobra.Command{ + Use: "create-root-key", + Short: "Create a new root key pair", + Long: `Create a new root key pair and specify an expiration time for it.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + // Validate expiration + if rootExpiration <= 0 { + return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)") + } + + // Run main logic + if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil { + return fmt.Errorf("failed to generate root key: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(createRootKeyCmd) + createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file") + createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file") + createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)") + + if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil { + panic(err) + } + if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil { + panic(err) + } + if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil { + panic(err) + } +} + +func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error { + rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration) + if err != nil { + return fmt.Errorf("generate root key: %w", err) + } + + // Write private key + if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil { + return fmt.Errorf("write private key file (%s): %w", privKeyFile, err) + } + + // Write public key + if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil { + return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err) + } + + cmd.Printf("%s\n\n", rk.String()) + cmd.Printf("✅ Root key pair generated successfully.\n") + return nil +} diff --git a/client/cmd/up.go b/client/cmd/up.go index 140ba2cb2..9efc2e60d 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr r := peer.NewRecorder(config.ManagementURL.String()) r.GetFullStatus() - connectClient := internal.NewConnectClient(ctx, config, r) + connectClient := internal.NewConnectClient(ctx, config, r, false) SetupDebugHandler(ctx, config, r, connectClient, "") return connectClient.Run(nil) diff --git a/client/cmd/update.go b/client/cmd/update.go new file mode 100644 index 000000000..dc49b02c3 --- /dev/null +++ b/client/cmd/update.go @@ -0,0 +1,13 @@ +//go:build !windows && !darwin + +package cmd + +import ( + "github.com/spf13/cobra" +) + +var updateCmd *cobra.Command + +func isUpdateBinary() bool { + return false +} diff --git a/client/cmd/update_supported.go b/client/cmd/update_supported.go new file mode 100644 index 000000000..977875093 --- /dev/null +++ b/client/cmd/update_supported.go @@ -0,0 +1,75 @@ +//go:build windows || darwin + +package cmd + +import ( + "context" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/util" +) + +var ( + updateCmd = &cobra.Command{ + Use: "update", + Short: "Update the NetBird client application", + RunE: updateFunc, + } + + tempDirFlag string + installerFile string + serviceDirFlag string + dryRunFlag bool +) + +func init() { + updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir") + updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file") + updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory") + updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes") +} + +// isUpdateBinary checks if the current executable is named "update" or "update.exe" +func isUpdateBinary() bool { + // Remove extension for cross-platform compatibility + execPath, err := os.Executable() + if err != nil { + return false + } + baseName := filepath.Base(execPath) + name := strings.TrimSuffix(baseName, filepath.Ext(baseName)) + + return name == installer.UpdaterBinaryNameWithoutExtension() +} + +func updateFunc(cmd *cobra.Command, args []string) error { + if err := setupLogToFile(tempDirFlag); err != nil { + return err + } + + log.Infof("updater started: %s", serviceDirFlag) + updater := installer.NewWithDir(tempDirFlag) + if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil { + log.Errorf("failed to update application: %v", err) + return err + } + return nil +} + +func setupLogToFile(dir string) error { + logFile := filepath.Join(dir, installer.LogFile) + + if _, err := os.Stat(logFile); err == nil { + if err := os.Remove(logFile); err != nil { + log.Errorf("failed to remove existing log file: %v\n", err) + } + } + + return util.InitLog(logLevel, util.LogConsole, logFile) +} diff --git a/client/embed/embed.go b/client/embed/embed.go index 3090ca6a2..353c5438f 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error { } recorder := peer.NewRecorder(c.config.ManagementURL.String()) - client := internal.NewConnectClient(ctx, c.config, recorder) + client := internal.NewConnectClient(ctx, c.config, recorder, false) // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available diff --git a/client/internal/connect.go b/client/internal/connect.go index e9d422a28..79ec2f907 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -24,10 +24,14 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/internal/updatemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/client/system" mgm "github.com/netbirdio/netbird/shared/management/client" mgmProto "github.com/netbirdio/netbird/shared/management/proto" @@ -39,11 +43,13 @@ import ( ) type ConnectClient struct { - ctx context.Context - config *profilemanager.Config - statusRecorder *peer.Status - engine *Engine - engineMutex sync.Mutex + ctx context.Context + config *profilemanager.Config + statusRecorder *peer.Status + doInitialAutoUpdate bool + + engine *Engine + engineMutex sync.Mutex persistSyncResponse bool } @@ -52,13 +58,15 @@ func NewConnectClient( ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, + doInitalAutoUpdate bool, ) *ConnectClient { return &ConnectClient{ - ctx: ctx, - config: config, - statusRecorder: statusRecorder, - engineMutex: sync.Mutex{}, + ctx: ctx, + config: config, + statusRecorder: statusRecorder, + doInitialAutoUpdate: doInitalAutoUpdate, + engineMutex: sync.Mutex{}, } } @@ -162,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return err } + sm := profilemanager.NewServiceManager("") + + path := sm.GetStatePath() + if runtime.GOOS == "ios" || runtime.GOOS == "android" { + if !fileExists(mobileDependency.StateFilePath) { + err := createFile(mobileDependency.StateFilePath) + if err != nil { + log.Errorf("failed to create state file: %v", err) + // we are not exiting as we can run without the state manager + } + } + + path = mobileDependency.StateFilePath + } + stateManager := statemanager.New(path) + stateManager.RegisterState(&sshconfig.ShutdownState{}) + + updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager) + if err == nil { + updateManager.CheckUpdateSuccess(c.ctx) + + inst := installer.New() + if err := inst.CleanUpInstallerFiles(); err != nil { + log.Errorf("failed to clean up temporary installer file: %v", err) + } + } + defer c.statusRecorder.ClientStop() operation := func() error { // if context cancelled we not start new backoff cycle @@ -273,7 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) + engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager) engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine = engine c.engineMutex.Unlock() @@ -283,6 +318,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } + if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil { + // AutoUpdate will be true when the user click on "Connect" menu on the UI + if c.doInitialAutoUpdate { + log.Infof("start engine by ui, run auto-update check") + c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate) + c.doInitialAutoUpdate = false + } + } + log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 3c201ecfc..01a0377a5 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" mgmProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -362,6 +363,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add systemd logs: %v", err) } + if err := g.addUpdateLogs(); err != nil { + log.Errorf("failed to add updater logs: %v", err) + } + return nil } @@ -650,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error { return nil } +func (g *BundleGenerator) addUpdateLogs() error { + inst := installer.New() + logFiles := inst.LogFiles() + if len(logFiles) == 0 { + return nil + } + + log.Infof("adding updater logs") + for _, logFile := range logFiles { + data, err := os.ReadFile(logFile) + if err != nil { + log.Warnf("failed to read update log file %s: %v", logFile, err) + continue + } + + baseName := filepath.Base(logFile) + if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil { + return fmt.Errorf("add update log file %s to zip: %w", baseName, err) + } + } + return nil +} + func (g *BundleGenerator) addCorruptedStateFiles() error { sm := profilemanager.NewServiceManager("") pattern := sm.GetStatePath() diff --git a/client/internal/engine.go b/client/internal/engine.go index ff1cec19a..084bafdbd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -42,14 +42,13 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" - "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager" cProto "github.com/netbirdio/netbird/client/proto" - sshconfig "github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/shared/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" @@ -73,6 +72,7 @@ const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms connInitLimit = 200 + disableAutoUpdate = "disabled" ) var ErrResetConnection = fmt.Errorf("reset connection") @@ -201,6 +201,9 @@ type Engine struct { connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager + // auto-update + updateManager *updatemanager.Manager + // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor @@ -221,17 +224,7 @@ type localIpUpdater interface { } // NewEngine creates a new Connection Engine with probes attached -func NewEngine( - clientCtx context.Context, - clientCancel context.CancelFunc, - signalClient signal.Client, - mgmClient mgm.Client, - relayManager *relayClient.Manager, - config *EngineConfig, - mobileDep MobileDependency, - statusRecorder *peer.Status, - checks []*mgmProto.Checks, -) *Engine { +func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine { engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, @@ -247,28 +240,12 @@ func NewEngine( TURNs: []*stun.URI{}, networkSerial: 0, statusRecorder: statusRecorder, + stateManager: stateManager, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } - sm := profilemanager.NewServiceManager("") - - path := sm.GetStatePath() - if runtime.GOOS == "ios" || runtime.GOOS == "android" { - if !fileExists(mobileDep.StateFilePath) { - err := createFile(mobileDep.StateFilePath) - if err != nil { - log.Errorf("failed to create state file: %v", err) - // we are not exiting as we can run without the state manager - } - } - - path = mobileDep.StateFilePath - } - engine.stateManager = statemanager.New(path) - engine.stateManager.RegisterState(&sshconfig.ShutdownState{}) - log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) return engine } @@ -308,6 +285,10 @@ func (e *Engine) Stop() error { e.srWatcher.Close() } + if e.updateManager != nil { + e.updateManager.Stop() + } + log.Info("cleaning up status recorder states") e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) @@ -541,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return nil } +func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + e.handleAutoUpdateVersion(autoUpdateSettings, true) +} + func (e *Engine) createFirewall() error { if e.config.DisableFirewall { log.Infof("firewall is disabled") @@ -749,6 +737,41 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg return nil } +func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) { + if autoUpdateSettings == nil { + return + } + + disabled := autoUpdateSettings.Version == disableAutoUpdate + + // Stop and cleanup if disabled + if e.updateManager != nil && disabled { + log.Infof("auto-update is disabled, stopping update manager") + e.updateManager.Stop() + e.updateManager = nil + return + } + + // Skip check unless AlwaysUpdate is enabled or this is the initial check at startup + if !autoUpdateSettings.AlwaysUpdate && !initialCheck { + log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check") + return + } + + // Start manager if needed + if e.updateManager == nil { + log.Infof("starting auto-update manager") + updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager) + if err != nil { + return + } + e.updateManager = updateManager + e.updateManager.Start(e.ctx) + } + log.Infof("handling auto-update version: %s", autoUpdateSettings.Version) + e.updateManager.SetVersion(autoUpdateSettings.Version) +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -758,6 +781,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return e.ctx.Err() } + if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { + e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false) + } + if update.GetNetbirdConfig() != nil { wCfg := update.GetNetbirdConfig() err := e.updateTURNs(wCfg.GetTurns()) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 5ab21e3e1..26ea6f8c2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -253,6 +253,7 @@ func TestEngine_SSH(t *testing.T) { MobileDependency{}, peer.NewRecorder("https://mgm"), nil, + nil, ) engine.dnsServer = &dns.MockServer{ @@ -414,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine( - ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, - &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - MTU: iface.DefaultMTU, - }, - MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -647,7 +640,7 @@ func TestEngine_Sync(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -812,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { @@ -1014,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) @@ -1540,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil + e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil e.ctx = ctx return e, err } diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 8f467a214..84ee73902 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "os/user" "path/filepath" "reflect" "runtime" @@ -165,19 +166,26 @@ func getConfigDir() (string, error) { if ConfigDirOverride != "" { return ConfigDirOverride, nil } - configDir, err := os.UserConfigDir() + + base, err := baseConfigDir() if err != nil { return "", err } - configDir = filepath.Join(configDir, "netbird") - if _, err := os.Stat(configDir); os.IsNotExist(err) { - if err := os.MkdirAll(configDir, 0755); err != nil { - return "", err + configDir := filepath.Join(base, "netbird") + if err := os.MkdirAll(configDir, 0o755); err != nil { + return "", err + } + return configDir, nil +} + +func baseConfigDir() (string, error) { + if runtime.GOOS == "darwin" { + if u, err := user.Current(); err == nil && u.HomeDir != "" { + return filepath.Join(u.HomeDir, "Library", "Application Support"), nil } } - - return configDir, nil + return os.UserConfigDir() } func getConfigDirForUser(username string) (string, error) { diff --git a/client/internal/updatemanager/doc.go b/client/internal/updatemanager/doc.go new file mode 100644 index 000000000..54d1bdeab --- /dev/null +++ b/client/internal/updatemanager/doc.go @@ -0,0 +1,35 @@ +// Package updatemanager provides automatic update management for the NetBird client. +// It monitors for new versions, handles update triggers from management server directives, +// and orchestrates the download and installation of client updates. +// +// # Overview +// +// The update manager operates as a background service that continuously monitors for +// available updates and automatically initiates the update process when conditions are met. +// It integrates with the installer package to perform the actual installation. +// +// # Update Flow +// +// The complete update process follows these steps: +// +// 1. Manager receives update directive via SetVersion() or detects new version +// 2. Manager validates update should proceed (version comparison, rate limiting) +// 3. Manager publishes "updating" event to status recorder +// 4. Manager persists UpdateState to track update attempt +// 5. Manager downloads installer file (.msi or .exe) to temporary directory +// 6. Manager triggers installation via installer.RunInstallation() +// 7. Installer package handles the actual installation process +// 8. On next startup, CheckUpdateSuccess() verifies update completion +// 9. Manager publishes success/failure event to status recorder +// 10. Manager cleans up UpdateState +// +// # State Management +// +// Update state is persisted across restarts to track update attempts: +// +// - PreUpdateVersion: Version before update attempt +// - TargetVersion: Version attempting to update to +// +// This enables verification of successful updates and appropriate user notification +// after the client restarts with the new version. +package updatemanager diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updatemanager/downloader/downloader.go new file mode 100644 index 000000000..2ac36efed --- /dev/null +++ b/client/internal/updatemanager/downloader/downloader.go @@ -0,0 +1,138 @@ +package downloader + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/version" +) + +const ( + userAgent = "NetBird agent installer/%s" + DefaultRetryDelay = 3 * time.Second +) + +func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error { + log.Debugf("starting download from %s", url) + + out, err := os.Create(dstFile) + if err != nil { + return fmt.Errorf("failed to create destination file %q: %w", dstFile, err) + } + defer func() { + if cerr := out.Close(); cerr != nil { + log.Warnf("error closing file %q: %v", dstFile, cerr) + } + }() + + // First attempt + err = downloadToFileOnce(ctx, url, out) + if err == nil { + log.Infof("successfully downloaded file to %s", dstFile) + return nil + } + + // If retryDelay is 0, don't retry + if retryDelay == 0 { + return err + } + + log.Warnf("download failed, retrying after %v: %v", retryDelay, err) + + // Sleep before retry + if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil { + return fmt.Errorf("download cancelled during retry delay: %w", sleepErr) + } + + // Truncate file before retry + if err := out.Truncate(0); err != nil { + return fmt.Errorf("failed to truncate file on retry: %w", err) + } + if _, err := out.Seek(0, 0); err != nil { + return fmt.Errorf("failed to seek to beginning of file: %w", err) + } + + // Second attempt + if err := downloadToFileOnce(ctx, url, out); err != nil { + return fmt.Errorf("download failed after retry: %w", err) + } + + log.Infof("successfully downloaded file to %s", dstFile) + return nil +} + +func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Add User-Agent header + req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion())) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform HTTP request: %w", err) + } + defer func() { + if cerr := resp.Body.Close(); cerr != nil { + log.Warnf("error closing response body: %v", cerr) + } + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode) + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, limit)) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + return data, nil +} + +func downloadToFileOnce(ctx context.Context, url string, out *os.File) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Add User-Agent header + req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion())) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to perform HTTP request: %w", err) + } + defer func() { + if cerr := resp.Body.Close(); cerr != nil { + log.Warnf("error closing response body: %v", cerr) + } + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode) + } + + if _, err := io.Copy(out, resp.Body); err != nil { + return fmt.Errorf("failed to write response body to file: %w", err) + } + + return nil +} + +func sleepWithContext(ctx context.Context, duration time.Duration) error { + select { + case <-time.After(duration): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updatemanager/downloader/downloader_test.go new file mode 100644 index 000000000..045db3a2d --- /dev/null +++ b/client/internal/updatemanager/downloader/downloader_test.go @@ -0,0 +1,199 @@ +package downloader + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" +) + +const ( + retryDelay = 100 * time.Millisecond +) + +func TestDownloadToFile_Success(t *testing.T) { + // Create a test server that responds successfully + content := "test file content" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file + err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify the file content + data, err := os.ReadFile(dstFile) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + + if string(data) != content { + t.Errorf("expected content %q, got %q", content, string(data)) + } +} + +func TestDownloadToFile_SuccessAfterRetry(t *testing.T) { + content := "test file content after retry" + var attemptCount atomic.Int32 + + // Create a test server that fails on first attempt, succeeds on second + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := attemptCount.Add(1) + if attempt == 1 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file (should succeed after retry) + if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil { + t.Fatalf("expected no error after retry, got: %v", err) + } + + // Verify the file content + data, err := os.ReadFile(dstFile) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + + if string(data) != content { + t.Errorf("expected content %q, got %q", content, string(data)) + } + + // Verify it took 2 attempts + if attemptCount.Load() != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_FailsAfterRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file (should fail after retry) + if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil { + t.Fatal("expected error after retry, got nil") + } + + // Verify it tried 2 times + if attemptCount.Load() != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Create a context that will be cancelled during retry delay + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel after a short delay (during the retry sleep) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + // Download the file (should fail due to context cancellation during retry) + err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile) + if err == nil { + t.Fatal("expected error due to context cancellation, got nil") + } + + // Should have only made 1 attempt (cancelled during retry delay) + if attemptCount.Load() != 1 { + t.Errorf("expected 1 attempt, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_InvalidURL(t *testing.T) { + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile) + if err == nil { + t.Fatal("expected error for invalid URL, got nil") + } +} + +func TestDownloadToFile_InvalidDestination(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("test")) + })) + defer server.Close() + + // Use an invalid destination path + err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt") + if err == nil { + t.Fatal("expected error for invalid destination, got nil") + } +} + +func TestDownloadToFile_NoRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file with retryDelay = 0 (should not retry) + if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil { + t.Fatal("expected error, got nil") + } + + // Verify it only made 1 attempt (no retry) + if attemptCount.Load() != 1 { + t.Errorf("expected 1 attempt, got %d", attemptCount.Load()) + } +} diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updatemanager/installer/binary_nowindows.go new file mode 100644 index 000000000..19f3bef83 --- /dev/null +++ b/client/internal/updatemanager/installer/binary_nowindows.go @@ -0,0 +1,7 @@ +//go:build !windows + +package installer + +func UpdaterBinaryNameWithoutExtension() string { + return updaterBinary +} diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updatemanager/installer/binary_windows.go new file mode 100644 index 000000000..4c66391c2 --- /dev/null +++ b/client/internal/updatemanager/installer/binary_windows.go @@ -0,0 +1,11 @@ +package installer + +import ( + "path/filepath" + "strings" +) + +func UpdaterBinaryNameWithoutExtension() string { + ext := filepath.Ext(updaterBinary) + return strings.TrimSuffix(updaterBinary, ext) +} diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updatemanager/installer/doc.go new file mode 100644 index 000000000..0a60454bb --- /dev/null +++ b/client/internal/updatemanager/installer/doc.go @@ -0,0 +1,111 @@ +// Package installer provides functionality for managing NetBird application +// updates and installations across Windows, macOS. It handles +// the complete update lifecycle including artifact download, cryptographic verification, +// installation execution, process management, and result reporting. +// +// # Architecture +// +// The installer package uses a two-process architecture to enable self-updates: +// +// 1. Service Process: The main NetBird daemon process that initiates updates +// 2. Updater Process: A detached child process that performs the actual installation +// +// This separation is critical because: +// - The service binary cannot update itself while running +// - The installer (EXE/MSI/PKG) will terminate the service during installation +// - The updater process survives service termination and restarts it after installation +// - Results can be communicated back to the service after it restarts +// +// # Update Flow +// +// Service Process (RunInstallation): +// +// 1. Validates target version format (semver) +// 2. Determines installer type (EXE, MSI, PKG, or Homebrew) +// 3. Downloads installer file from GitHub releases (if applicable) +// 4. Verifies installer signature using reposign package (cryptographic verification in service process before +// launching updater) +// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows) +// 6. Launches updater process with detached mode: +// - --temp-dir: Temporary directory path +// - --service-dir: Service installation directory +// - --installer-file: Path to downloaded installer (if applicable) +// - --dry-run: Optional flag to test without actually installing +// 7. Service process continues running (will be terminated by installer later) +// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion +// +// Updater Process (Setup): +// +// 1. Receives parameters from service via command-line arguments +// 2. Runs installer with appropriate silent/quiet flags: +// - Windows EXE: installer.exe /S +// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log +// - macOS PKG: installer -pkg installer.pkg -target / +// - macOS Homebrew: brew upgrade netbirdio/tap/netbird +// 3. Installer terminates daemon and UI processes +// 4. Installer replaces binaries with new version +// 5. Updater waits for installer to complete +// 6. Updater restarts daemon: +// - Windows: netbird.exe service start +// - macOS/Linux: netbird service start +// 7. Updater restarts UI: +// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser +// - macOS: Uses launchctl asuser to launch NetBird.app for console user +// - Linux: Not implemented (UI typically auto-starts) +// 8. Updater writes result.json with success/error status +// 9. Updater process exits +// +// # Result Communication +// +// The ResultHandler (result.go) manages communication between updater and service: +// +// Result Structure: +// +// type Result struct { +// Success bool // true if installation succeeded +// Error string // error message if Success is false +// ExecutedAt time.Time // when installation completed +// } +// +// Result files are automatically cleaned up after being read. +// +// # File Locations +// +// Temporary Directory (platform-specific): +// +// Windows: +// - Path: %ProgramData%\Netbird\tmp-install +// - Example: C:\ProgramData\Netbird\tmp-install +// +// macOS: +// - Path: /var/lib/netbird/tmp-install +// - Requires root permissions +// +// Files created during installation: +// +// tmp-install/ +// installer.log +// updater[.exe] # Copy of service binary +// netbird_installer_*.[exe|msi|pkg] # Downloaded installer +// result.json # Installation result +// msi.log # MSI verbose log (Windows MSI only) +// +// # API Reference +// +// # Cleanup +// +// CleanUpInstallerFiles() removes temporary files after successful installation: +// - Downloaded installer files (*.exe, *.msi, *.pkg) +// - Updater binary copy +// - Does NOT remove result.json (cleaned by ResultHandler after read) +// - Does NOT remove msi.log (kept for debugging) +// +// # Dry-Run Mode +// +// Dry-run mode allows testing the update process without actually installing: +// +// Enable via environment variable: +// +// export NB_AUTO_UPDATE_DRY_RUN=true +// netbird service install-update 0.29.0 +package installer diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updatemanager/installer/installer.go new file mode 100644 index 000000000..caf5873f8 --- /dev/null +++ b/client/internal/updatemanager/installer/installer.go @@ -0,0 +1,50 @@ +//go:build !windows && !darwin + +package installer + +import ( + "context" + "fmt" +) + +const ( + updaterBinary = "updater" +) + +type Installer struct { + tempDir string +} + +// New used by the service +func New() *Installer { + return &Installer{} +} + +// NewWithDir used by the updater process, get the tempDir from the service via cmd line +func NewWithDir(tempDir string) *Installer { + return &Installer{ + tempDir: tempDir, + } +} + +func (u *Installer) TempDir() string { + return "" +} + +func (c *Installer) LogFiles() []string { + return []string{} +} + +func (u *Installer) CleanUpInstallerFiles() error { + return nil +} + +func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error { + return fmt.Errorf("unsupported platform") +} + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) { + return fmt.Errorf("unsupported platform") +} diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updatemanager/installer/installer_common.go new file mode 100644 index 000000000..03378d55f --- /dev/null +++ b/client/internal/updatemanager/installer/installer_common.go @@ -0,0 +1,293 @@ +//go:build windows || darwin + +package installer + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + + "github.com/hashicorp/go-multierror" + goversion "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +type Installer struct { + tempDir string +} + +// New used by the service +func New() *Installer { + return &Installer{ + tempDir: defaultTempDir, + } +} + +// NewWithDir used by the updater process, get the tempDir from the service via cmd line +func NewWithDir(tempDir string) *Installer { + return &Installer{ + tempDir: tempDir, + } +} + +// RunInstallation starts the updater process to run the installation +// This will run by the original service process +func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) { + resultHandler := NewResultHandler(u.tempDir) + + defer func() { + if err != nil { + if writeErr := resultHandler.WriteErr(err); writeErr != nil { + log.Errorf("failed to write error result: %v", writeErr) + } + } + }() + + if err := validateTargetVersion(targetVersion); err != nil { + return err + } + + if err := u.mkTempDir(); err != nil { + return err + } + + var installerFile string + // Download files only when not using any third-party store + if installerType := TypeOfInstaller(ctx); installerType.Downloadable() { + log.Infof("download installer") + var err error + installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion) + if err != nil { + log.Errorf("failed to download installer: %v", err) + return err + } + + artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL) + if err != nil { + log.Errorf("failed to create artifact verify: %v", err) + return err + } + + if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil { + log.Errorf("artifact verification error: %v", err) + return err + } + } + + log.Infof("running installer") + updaterPath, err := u.copyUpdater() + if err != nil { + return err + } + + // the directory where the service has been installed + workspace, err := getServiceDir() + if err != nil { + return err + } + + args := []string{ + "--temp-dir", u.tempDir, + "--service-dir", workspace, + } + + if isDryRunEnabled() { + args = append(args, "--dry-run=true") + } + + if installerFile != "" { + args = append(args, "--installer-file", installerFile) + } + + updateCmd := exec.Command(updaterPath, args...) + log.Infof("starting updater process: %s", updateCmd.String()) + + // Configure the updater to run in a separate session/process group + // so it survives the parent daemon being stopped + setUpdaterProcAttr(updateCmd) + + // Start the updater process asynchronously + if err := updateCmd.Start(); err != nil { + return err + } + + pid := updateCmd.Process.Pid + log.Infof("updater started with PID %d", pid) + + // Release the process so the OS can fully detach it + if err := updateCmd.Process.Release(); err != nil { + log.Warnf("failed to release updater process: %v", err) + } + + return nil +} + +// CleanUpInstallerFiles +// - the installer file (pkg, exe, msi) +// - the selfcopy updater.exe +func (u *Installer) CleanUpInstallerFiles() error { + // Check if tempDir exists + info, err := os.Stat(u.tempDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + if !info.IsDir() { + return nil + } + + var merr *multierror.Error + + if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) { + merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err)) + } + + entries, err := os.ReadDir(u.tempDir) + if err != nil { + return err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + for _, ext := range binaryExtensions { + if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) { + if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil { + merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err)) + } + break + } + } + } + + return merr.ErrorOrNil() +} + +func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) { + fileURL := urlWithVersionArch(installerType, targetVersion) + + // Clean up temp directory on error + var success bool + defer func() { + if !success { + if err := os.RemoveAll(u.tempDir); err != nil { + log.Errorf("error cleaning up temporary directory: %v", err) + } + } + }() + + fileName := path.Base(fileURL) + if fileName == "." || fileName == "/" || fileName == "" { + return "", fmt.Errorf("invalid file URL: %s", fileURL) + } + + outputFilePath := filepath.Join(u.tempDir, fileName) + if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil { + return "", err + } + + success = true + return outputFilePath, nil +} + +func (u *Installer) TempDir() string { + return u.tempDir +} + +func (u *Installer) mkTempDir() error { + if err := os.MkdirAll(u.tempDir, 0o755); err != nil { + log.Debugf("failed to create tempdir: %s", u.tempDir) + return err + } + return nil +} + +func (u *Installer) copyUpdater() (string, error) { + src, err := getServiceBinary() + if err != nil { + return "", fmt.Errorf("failed to get updater binary: %w", err) + } + + dst := filepath.Join(u.tempDir, updaterBinary) + if err := copyFile(src, dst); err != nil { + return "", fmt.Errorf("failed to copy updater binary: %w", err) + } + + if err := os.Chmod(dst, 0o755); err != nil { + return "", fmt.Errorf("failed to set permissions: %w", err) + } + + return dst, nil +} + +func validateTargetVersion(targetVersion string) error { + if targetVersion == "" { + return fmt.Errorf("target version cannot be empty") + } + + _, err := goversion.NewVersion(targetVersion) + if err != nil { + return fmt.Errorf("invalid target version %q: %w", targetVersion, err) + } + + return nil +} + +func copyFile(src, dst string) error { + log.Infof("copying %s to %s", src, dst) + in, err := os.Open(src) + if err != nil { + return fmt.Errorf("open source: %w", err) + } + defer func() { + if err := in.Close(); err != nil { + log.Warnf("failed to close source file: %v", err) + } + }() + + out, err := os.Create(dst) + if err != nil { + return fmt.Errorf("create destination: %w", err) + } + defer func() { + if err := out.Close(); err != nil { + log.Warnf("failed to close destination file: %v", err) + } + }() + + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("copy: %w", err) + } + + return nil +} + +func getServiceDir() (string, error) { + exePath, err := os.Executable() + if err != nil { + return "", err + } + return filepath.Dir(exePath), nil +} + +func getServiceBinary() (string, error) { + return os.Executable() +} + +func isDryRunEnabled() bool { + return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true") +} diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updatemanager/installer/installer_log_darwin.go new file mode 100644 index 000000000..50dd5d197 --- /dev/null +++ b/client/internal/updatemanager/installer/installer_log_darwin.go @@ -0,0 +1,11 @@ +package installer + +import ( + "path/filepath" +) + +func (u *Installer) LogFiles() []string { + return []string{ + filepath.Join(u.tempDir, LogFile), + } +} diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updatemanager/installer/installer_log_windows.go new file mode 100644 index 000000000..96e4cfd1f --- /dev/null +++ b/client/internal/updatemanager/installer/installer_log_windows.go @@ -0,0 +1,12 @@ +package installer + +import ( + "path/filepath" +) + +func (u *Installer) LogFiles() []string { + return []string{ + filepath.Join(u.tempDir, msiLogFile), + filepath.Join(u.tempDir, LogFile), + } +} diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updatemanager/installer/installer_run_darwin.go new file mode 100644 index 000000000..462e2c227 --- /dev/null +++ b/client/internal/updatemanager/installer/installer_run_darwin.go @@ -0,0 +1,238 @@ +package installer + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + daemonName = "netbird" + updaterBinary = "updater" + uiBinary = "/Applications/NetBird.app" + + defaultTempDir = "/var/lib/netbird/tmp-install" + + pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg" +) + +var ( + binaryExtensions = []string{"pkg"} +) + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) { + resultHandler := NewResultHandler(u.tempDir) + + // Always ensure daemon and UI are restarted after setup + defer func() { + log.Infof("write out result") + var err error + if resultErr == nil { + err = resultHandler.WriteSuccess() + } else { + err = resultHandler.WriteErr(resultErr) + } + if err != nil { + log.Errorf("failed to write update result: %v", err) + } + + // skip service restart if dry-run mode is enabled + if dryRun { + return + } + + log.Infof("starting daemon back") + if err := u.startDaemon(daemonFolder); err != nil { + log.Errorf("failed to start daemon: %v", err) + } + + log.Infof("starting UI back") + if err := u.startUIAsUser(); err != nil { + log.Errorf("failed to start UI: %v", err) + } + + }() + + if dryRun { + time.Sleep(7 * time.Second) + log.Infof("dry-run mode enabled, skipping actual installation") + resultErr = fmt.Errorf("dry-run mode enabled") + return + } + + switch TypeOfInstaller(ctx) { + case TypePKG: + resultErr = u.installPkgFile(ctx, installerFile) + case TypeHomebrew: + resultErr = u.updateHomeBrew(ctx) + } + + return resultErr +} + +func (u *Installer) startDaemon(daemonFolder string) error { + log.Infof("starting netbird service") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start") + if output, err := cmd.CombinedOutput(); err != nil { + log.Warnf("failed to start netbird service: %v, output: %s", err, string(output)) + return err + } + log.Infof("netbird service started successfully") + return nil +} + +func (u *Installer) startUIAsUser() error { + log.Infof("starting netbird-ui: %s", uiBinary) + + // Get the current console user + cmd := exec.Command("stat", "-f", "%Su", "/dev/console") + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("failed to get console user: %w", err) + } + + username := strings.TrimSpace(string(output)) + if username == "" || username == "root" { + return fmt.Errorf("no active user session found") + } + + log.Infof("starting UI for user: %s", username) + + // Get user's UID + userInfo, err := user.Lookup(username) + if err != nil { + return fmt.Errorf("failed to lookup user %s: %w", username, err) + } + + // Start the UI process as the console user using launchctl + // This ensures the app runs in the user's context with proper GUI access + launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary) + log.Infof("launchCmd: %s", launchCmd.String()) + // Set the user's home directory for proper macOS app behavior + launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir) + log.Infof("set HOME environment variable: %s", userInfo.HomeDir) + + if err := launchCmd.Start(); err != nil { + return fmt.Errorf("failed to start UI process: %w", err) + } + + // Release the process so it can run independently + if err := launchCmd.Process.Release(); err != nil { + log.Warnf("failed to release UI process: %v", err) + } + + log.Infof("netbird-ui started successfully for user %s", username) + return nil +} + +func (u *Installer) installPkgFile(ctx context.Context, path string) error { + log.Infof("installing pkg file: %s", path) + + // Kill any existing UI processes before installation + // This ensures the postinstall script's "open $APP" will start the new version + u.killUI() + + volume := "/" + + cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume) + if err := cmd.Start(); err != nil { + return fmt.Errorf("error running pkg file: %w", err) + } + log.Infof("installer started with PID %d", cmd.Process.Pid) + if err := cmd.Wait(); err != nil { + return fmt.Errorf("error running pkg file: %w", err) + } + log.Infof("pkg file installed successfully") + return nil +} + +func (u *Installer) updateHomeBrew(ctx context.Context) error { + log.Infof("updating homebrew") + + // Kill any existing UI processes before upgrade + // This ensures the new version will be started after upgrade + u.killUI() + + // Homebrew must be run as a non-root user + // To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory + // Check both Apple Silicon and Intel Mac paths + brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/" + brewBinPath := "/opt/homebrew/bin/brew" + if _, err := os.Stat(brewTapPath); os.IsNotExist(err) { + // Try Intel Mac path + brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/" + brewBinPath = "/usr/local/bin/brew" + } + + fileInfo, err := os.Stat(brewTapPath) + if err != nil { + return fmt.Errorf("error getting homebrew installation path info: %w", err) + } + + fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t) + if !ok { + return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys()) + } + + // Get username from UID + brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid)) + if err != nil { + return fmt.Errorf("error looking up brew installer user: %w", err) + } + userName := brewUser.Username + // Get user HOME, required for brew to run correctly + // https://github.com/Homebrew/brew/issues/15833 + homeDir := brewUser.HomeDir + + // Check if netbird-ui is installed (must run as the brew user, not root) + checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui") + checkUICmd.Env = append(os.Environ(), "HOME="+homeDir) + uiInstalled := checkUICmd.Run() == nil + + // Homebrew does not support installing specific versions + // Thus it will always update to latest and ignore targetVersion + upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"} + if uiInstalled { + upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui") + } + + cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...) + cmd.Env = append(os.Environ(), "HOME="+homeDir) + + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output)) + } + + log.Infof("homebrew updated successfully") + return nil +} + +func (u *Installer) killUI() { + log.Infof("killing existing netbird-ui processes") + cmd := exec.Command("pkill", "-x", "netbird-ui") + if output, err := cmd.CombinedOutput(); err != nil { + // pkill returns exit code 1 if no processes matched, which is fine + log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output)) + } else { + log.Infof("netbird-ui processes killed") + } +} + +func urlWithVersionArch(_ Type, version string) string { + url := strings.ReplaceAll(pkgDownloadURL, "%version", version) + return strings.ReplaceAll(url, "%arch", runtime.GOARCH) +} diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updatemanager/installer/installer_run_windows.go new file mode 100644 index 000000000..353cd885d --- /dev/null +++ b/client/internal/updatemanager/installer/installer_run_windows.go @@ -0,0 +1,213 @@ +package installer + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + daemonName = "netbird.exe" + uiName = "netbird-ui.exe" + updaterBinary = "updater.exe" + + msiLogFile = "msi.log" + + msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi" + exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe" +) + +var ( + defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install") + + // for the cleanup + binaryExtensions = []string{"msi", "exe"} +) + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) { + resultHandler := NewResultHandler(u.tempDir) + + // Always ensure daemon and UI are restarted after setup + defer func() { + log.Infof("starting daemon back") + if err := u.startDaemon(daemonFolder); err != nil { + log.Errorf("failed to start daemon: %v", err) + } + + log.Infof("starting UI back") + if err := u.startUIAsUser(daemonFolder); err != nil { + log.Errorf("failed to start UI: %v", err) + } + + log.Infof("write out result") + var err error + if resultErr == nil { + err = resultHandler.WriteSuccess() + } else { + err = resultHandler.WriteErr(resultErr) + } + if err != nil { + log.Errorf("failed to write update result: %v", err) + } + }() + + if dryRun { + log.Infof("dry-run mode enabled, skipping actual installation") + resultErr = fmt.Errorf("dry-run mode enabled") + return + } + + installerType, err := typeByFileExtension(installerFile) + if err != nil { + log.Debugf("%v", err) + resultErr = err + return + } + + var cmd *exec.Cmd + switch installerType { + case TypeExe: + log.Infof("run exe installer: %s", installerFile) + cmd = exec.CommandContext(ctx, installerFile, "/S") + default: + installerDir := filepath.Dir(installerFile) + logPath := filepath.Join(installerDir, msiLogFile) + log.Infof("run msi installer: %s", installerFile) + cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath) + } + + cmd.Dir = filepath.Dir(installerFile) + + if resultErr = cmd.Start(); resultErr != nil { + log.Errorf("error starting installer: %v", resultErr) + return + } + + log.Infof("installer started with PID %d", cmd.Process.Pid) + if resultErr = cmd.Wait(); resultErr != nil { + log.Errorf("installer process finished with error: %v", resultErr) + return + } + + return nil +} + +func (u *Installer) startDaemon(daemonFolder string) error { + log.Infof("starting netbird service") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start") + if output, err := cmd.CombinedOutput(); err != nil { + log.Debugf("failed to start netbird service: %v, output: %s", err, string(output)) + return err + } + log.Infof("netbird service started successfully") + return nil +} + +func (u *Installer) startUIAsUser(daemonFolder string) error { + uiPath := filepath.Join(daemonFolder, uiName) + log.Infof("starting netbird-ui: %s", uiPath) + + // Get the active console session ID + sessionID := windows.WTSGetActiveConsoleSessionId() + if sessionID == 0xFFFFFFFF { + return fmt.Errorf("no active user session found") + } + + // Get the user token for that session + var userToken windows.Token + err := windows.WTSQueryUserToken(sessionID, &userToken) + if err != nil { + return fmt.Errorf("failed to query user token: %w", err) + } + defer func() { + if err := userToken.Close(); err != nil { + log.Warnf("failed to close user token: %v", err) + } + }() + + // Duplicate the token to a primary token + var primaryToken windows.Token + err = windows.DuplicateTokenEx( + userToken, + windows.MAXIMUM_ALLOWED, + nil, + windows.SecurityImpersonation, + windows.TokenPrimary, + &primaryToken, + ) + if err != nil { + return fmt.Errorf("failed to duplicate token: %w", err) + } + defer func() { + if err := primaryToken.Close(); err != nil { + log.Warnf("failed to close token: %v", err) + } + }() + + // Prepare startup info + var si windows.StartupInfo + si.Cb = uint32(unsafe.Sizeof(si)) + si.Desktop = windows.StringToUTF16Ptr("winsta0\\default") + + var pi windows.ProcessInformation + + cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath)) + if err != nil { + return fmt.Errorf("failed to convert path to UTF16: %w", err) + } + + creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT + + err = windows.CreateProcessAsUser( + primaryToken, + nil, + cmdLine, + nil, + nil, + false, + creationFlags, + nil, + nil, + &si, + &pi, + ) + if err != nil { + return fmt.Errorf("CreateProcessAsUser failed: %w", err) + } + + // Close handles + if err := windows.CloseHandle(pi.Process); err != nil { + log.Warnf("failed to close process handle: %v", err) + } + if err := windows.CloseHandle(pi.Thread); err != nil { + log.Warnf("failed to close thread handle: %v", err) + } + + log.Infof("netbird-ui started successfully in session %d", sessionID) + return nil +} + +func urlWithVersionArch(it Type, version string) string { + var url string + if it == TypeExe { + url = exeDownloadURL + } else { + url = msiDownloadURL + } + url = strings.ReplaceAll(url, "%version", version) + return strings.ReplaceAll(url, "%arch", runtime.GOARCH) +} diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updatemanager/installer/log.go new file mode 100644 index 000000000..8b60dba28 --- /dev/null +++ b/client/internal/updatemanager/installer/log.go @@ -0,0 +1,5 @@ +package installer + +const ( + LogFile = "installer.log" +) diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updatemanager/installer/procattr_darwin.go new file mode 100644 index 000000000..56f2018bb --- /dev/null +++ b/client/internal/updatemanager/installer/procattr_darwin.go @@ -0,0 +1,15 @@ +package installer + +import ( + "os/exec" + "syscall" +) + +// setUpdaterProcAttr configures the updater process to run in a new session, +// making it independent of the parent daemon process. This ensures the updater +// survives when the daemon is stopped during the pkg installation. +func setUpdaterProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + } +} diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updatemanager/installer/procattr_windows.go new file mode 100644 index 000000000..29a8a2de0 --- /dev/null +++ b/client/internal/updatemanager/installer/procattr_windows.go @@ -0,0 +1,14 @@ +package installer + +import ( + "os/exec" + "syscall" +) + +// setUpdaterProcAttr configures the updater process to run detached from the parent, +// making it independent of the parent daemon process. +func setUpdaterProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS + } +} diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updatemanager/installer/repourl_dev.go new file mode 100644 index 000000000..088821ad3 --- /dev/null +++ b/client/internal/updatemanager/installer/repourl_dev.go @@ -0,0 +1,7 @@ +//go:build devartifactsign + +package installer + +const ( + DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo" +) diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updatemanager/installer/repourl_prod.go new file mode 100644 index 000000000..abddc62c1 --- /dev/null +++ b/client/internal/updatemanager/installer/repourl_prod.go @@ -0,0 +1,7 @@ +//go:build !devartifactsign + +package installer + +const ( + DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures" +) diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updatemanager/installer/result.go new file mode 100644 index 000000000..03d08d527 --- /dev/null +++ b/client/internal/updatemanager/installer/result.go @@ -0,0 +1,230 @@ +package installer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/fsnotify/fsnotify" + log "github.com/sirupsen/logrus" +) + +const ( + resultFile = "result.json" +) + +type Result struct { + Success bool + Error string + ExecutedAt time.Time +} + +// ResultHandler handles reading and writing update results +type ResultHandler struct { + resultFile string +} + +// NewResultHandler creates a new communicator with the given directory path +// The result file will be created as "result.json" in the specified directory +func NewResultHandler(installerDir string) *ResultHandler { + // Create it if it doesn't exist + // do not care if already exists + _ = os.MkdirAll(installerDir, 0o700) + + rh := &ResultHandler{ + resultFile: filepath.Join(installerDir, resultFile), + } + return rh +} + +func (rh *ResultHandler) GetErrorResultReason() string { + result, err := rh.tryReadResult() + if err == nil && !result.Success { + return result.Error + } + + if err := rh.cleanup(); err != nil { + log.Warnf("failed to cleanup result file: %v", err) + } + + return "" +} + +func (rh *ResultHandler) WriteSuccess() error { + result := Result{ + Success: true, + ExecutedAt: time.Now(), + } + return rh.write(result) +} + +func (rh *ResultHandler) WriteErr(errReason error) error { + result := Result{ + Success: false, + Error: errReason.Error(), + ExecutedAt: time.Now(), + } + return rh.write(result) +} + +func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) { + log.Infof("start watching result: %s", rh.resultFile) + + // Check if file already exists (updater finished before we started watching) + if result, err := rh.tryReadResult(); err == nil { + log.Infof("installer result: %v", result) + return result, nil + } + + dir := filepath.Dir(rh.resultFile) + + if err := rh.waitForDirectory(ctx, dir); err != nil { + return Result{}, err + } + + return rh.watchForResultFile(ctx, dir) +} + +func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error { + ticker := time.NewTicker(300 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if info, err := os.Stat(dir); err == nil && info.IsDir() { + return nil + } + } + } +} + +func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Error(err) + return Result{}, err + } + + defer func() { + if err := watcher.Close(); err != nil { + log.Warnf("failed to close watcher: %v", err) + } + }() + + if err := watcher.Add(dir); err != nil { + return Result{}, fmt.Errorf("failed to watch directory: %v", err) + } + + // Check again after setting up watcher to avoid race condition + // (file could have been created between initial check and watcher setup) + if result, err := rh.tryReadResult(); err == nil { + log.Infof("installer result: %v", result) + return result, nil + } + + for { + select { + case <-ctx.Done(): + return Result{}, ctx.Err() + case event, ok := <-watcher.Events: + if !ok { + return Result{}, errors.New("watcher closed unexpectedly") + } + + if result, done := rh.handleWatchEvent(event); done { + return result, nil + } + case err, ok := <-watcher.Errors: + if !ok { + return Result{}, errors.New("watcher closed unexpectedly") + } + return Result{}, fmt.Errorf("watcher error: %w", err) + } + } +} + +func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) { + if event.Name != rh.resultFile { + return Result{}, false + } + + if event.Has(fsnotify.Create) { + result, err := rh.tryReadResult() + if err != nil { + log.Debugf("error while reading result: %v", err) + return result, true + } + log.Infof("installer result: %v", result) + return result, true + } + + return Result{}, false +} + +// Write writes the update result to a file for the UI to read +func (rh *ResultHandler) write(result Result) error { + log.Infof("write out installer result to: %s", rh.resultFile) + // Ensure directory exists + dir := filepath.Dir(rh.resultFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + log.Errorf("failed to create directory %s: %v", dir, err) + return err + } + + data, err := json.Marshal(result) + if err != nil { + return err + } + + // Write to a temporary file first, then rename for atomic operation + tmpPath := rh.resultFile + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o600); err != nil { + log.Errorf("failed to create temp file: %s", err) + return err + } + + // Atomic rename + if err := os.Rename(tmpPath, rh.resultFile); err != nil { + if cleanupErr := os.Remove(tmpPath); cleanupErr != nil { + log.Warnf("Failed to remove temp result file: %v", err) + } + return err + } + + return nil +} + +func (rh *ResultHandler) cleanup() error { + err := os.Remove(rh.resultFile) + if err != nil && !os.IsNotExist(err) { + return err + } + log.Debugf("delete installer result file: %s", rh.resultFile) + return nil +} + +// tryReadResult attempts to read and validate the result file +func (rh *ResultHandler) tryReadResult() (Result, error) { + data, err := os.ReadFile(rh.resultFile) + if err != nil { + return Result{}, err + } + + var result Result + if err := json.Unmarshal(data, &result); err != nil { + return Result{}, fmt.Errorf("invalid result format: %w", err) + } + + if err := rh.cleanup(); err != nil { + log.Warnf("failed to cleanup result file: %v", err) + } + + return result, nil +} diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updatemanager/installer/types.go new file mode 100644 index 000000000..656d84f88 --- /dev/null +++ b/client/internal/updatemanager/installer/types.go @@ -0,0 +1,14 @@ +package installer + +type Type struct { + name string + downloadable bool +} + +func (t Type) String() string { + return t.name +} + +func (t Type) Downloadable() bool { + return t.downloadable +} diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updatemanager/installer/types_darwin.go new file mode 100644 index 000000000..95a0cb737 --- /dev/null +++ b/client/internal/updatemanager/installer/types_darwin.go @@ -0,0 +1,22 @@ +package installer + +import ( + "context" + "os/exec" +) + +var ( + TypeHomebrew = Type{name: "Homebrew", downloadable: false} + TypePKG = Type{name: "pkg", downloadable: true} +) + +func TypeOfInstaller(ctx context.Context) Type { + cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client") + _, err := cmd.Output() + if err != nil && cmd.ProcessState.ExitCode() == 1 { + // Not installed using pkg file, thus installed using Homebrew + + return TypeHomebrew + } + return TypePKG +} diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updatemanager/installer/types_windows.go new file mode 100644 index 000000000..d4e5d83bd --- /dev/null +++ b/client/internal/updatemanager/installer/types_windows.go @@ -0,0 +1,51 @@ +package installer + +import ( + "context" + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows/registry" +) + +const ( + uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird` + uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird` +) + +var ( + TypeExe = Type{name: "EXE", downloadable: true} + TypeMSI = Type{name: "MSI", downloadable: true} +) + +func TypeOfInstaller(_ context.Context) Type { + paths := []string{uninstallKeyPath64, uninstallKeyPath32} + + for _, path := range paths { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + continue + } + + if err := k.Close(); err != nil { + log.Warnf("Error closing registry key: %v", err) + } + return TypeExe + + } + + log.Debug("No registry entry found for Netbird, assuming MSI installation") + return TypeMSI +} + +func typeByFileExtension(filePath string) (Type, error) { + switch { + case strings.HasSuffix(strings.ToLower(filePath), ".exe"): + return TypeExe, nil + case strings.HasSuffix(strings.ToLower(filePath), ".msi"): + return TypeMSI, nil + default: + return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath) + } +} diff --git a/client/internal/updatemanager/manager.go b/client/internal/updatemanager/manager.go new file mode 100644 index 000000000..eae11de56 --- /dev/null +++ b/client/internal/updatemanager/manager.go @@ -0,0 +1,374 @@ +//go:build windows || darwin + +package updatemanager + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync" + "time" + + v "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + cProto "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/version" +) + +const ( + latestVersion = "latest" + // this version will be ignored + developmentVersion = "development" +) + +var errNoUpdateState = errors.New("no update state found") + +type UpdateState struct { + PreUpdateVersion string + TargetVersion string +} + +func (u UpdateState) Name() string { + return "autoUpdate" +} + +type Manager struct { + statusRecorder *peer.Status + stateManager *statemanager.Manager + + lastTrigger time.Time + mgmUpdateChan chan struct{} + updateChannel chan struct{} + currentVersion string + update UpdateInterface + wg sync.WaitGroup + + cancel context.CancelFunc + + expectedVersion *v.Version + updateToLatestVersion bool + + // updateMutex protect update and expectedVersion fields + updateMutex sync.Mutex + + triggerUpdateFn func(context.Context, string) error +} + +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + if runtime.GOOS == "darwin" { + isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable() + if isBrew { + log.Warnf("auto-update disabled on Home Brew installation") + return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet") + } + } + return newManager(statusRecorder, stateManager) +} + +func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + manager := &Manager{ + statusRecorder: statusRecorder, + stateManager: stateManager, + mgmUpdateChan: make(chan struct{}, 1), + updateChannel: make(chan struct{}, 1), + currentVersion: version.NetbirdVersion(), + update: version.NewUpdate("nb/client"), + } + manager.triggerUpdateFn = manager.triggerUpdate + + stateManager.RegisterState(&UpdateState{}) + + return manager, nil +} + +// CheckUpdateSuccess checks if the update was successful and send a notification. +// It works without to start the update manager. +func (m *Manager) CheckUpdateSuccess(ctx context.Context) { + reason := m.lastResultErrReason() + if reason != "" { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_ERROR, + cProto.SystemEvent_SYSTEM, + "Auto-update failed", + fmt.Sprintf("Auto-update failed: %s", reason), + nil, + ) + } + + updateState, err := m.loadAndDeleteUpdateState(ctx) + if err != nil { + if errors.Is(err, errNoUpdateState) { + return + } + log.Errorf("failed to load update state: %v", err) + return + } + + log.Debugf("auto-update state loaded, %v", *updateState) + + if updateState.TargetVersion == m.currentVersion { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "Auto-update completed", + fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion), + nil, + ) + return + } +} + +func (m *Manager) Start(ctx context.Context) { + if m.cancel != nil { + log.Errorf("Manager already started") + return + } + + m.update.SetDaemonVersion(m.currentVersion) + m.update.SetOnUpdateListener(func() { + select { + case m.updateChannel <- struct{}{}: + default: + } + }) + go m.update.StartFetcher() + + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + + m.wg.Add(1) + go m.updateLoop(ctx) +} + +func (m *Manager) SetVersion(expectedVersion string) { + log.Infof("set expected agent version for upgrade: %s", expectedVersion) + if m.cancel == nil { + log.Errorf("manager not started") + return + } + + m.updateMutex.Lock() + defer m.updateMutex.Unlock() + + if expectedVersion == "" { + log.Errorf("empty expected version provided") + m.expectedVersion = nil + m.updateToLatestVersion = false + return + } + + if expectedVersion == latestVersion { + m.updateToLatestVersion = true + m.expectedVersion = nil + } else { + expectedSemVer, err := v.NewVersion(expectedVersion) + if err != nil { + log.Errorf("error parsing version: %v", err) + return + } + if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) { + return + } + m.expectedVersion = expectedSemVer + m.updateToLatestVersion = false + } + + select { + case m.mgmUpdateChan <- struct{}{}: + default: + } +} + +func (m *Manager) Stop() { + if m.cancel == nil { + return + } + + m.cancel() + m.updateMutex.Lock() + if m.update != nil { + m.update.StopWatch() + m.update = nil + } + m.updateMutex.Unlock() + + m.wg.Wait() +} + +func (m *Manager) onContextCancel() { + if m.cancel == nil { + return + } + + m.updateMutex.Lock() + defer m.updateMutex.Unlock() + if m.update != nil { + m.update.StopWatch() + m.update = nil + } +} + +func (m *Manager) updateLoop(ctx context.Context) { + defer m.wg.Done() + + for { + select { + case <-ctx.Done(): + m.onContextCancel() + return + case <-m.mgmUpdateChan: + case <-m.updateChannel: + log.Infof("fetched new version info") + } + + m.handleUpdate(ctx) + } +} + +func (m *Manager) handleUpdate(ctx context.Context) { + var updateVersion *v.Version + + m.updateMutex.Lock() + if m.update == nil { + m.updateMutex.Unlock() + return + } + + expectedVersion := m.expectedVersion + useLatest := m.updateToLatestVersion + curLatestVersion := m.update.LatestVersion() + m.updateMutex.Unlock() + + switch { + // Resolve "latest" to actual version + case useLatest: + if curLatestVersion == nil { + log.Tracef("latest version not fetched yet") + return + } + updateVersion = curLatestVersion + // Update to specific version + case expectedVersion != nil: + updateVersion = expectedVersion + default: + log.Debugf("no expected version information set") + return + } + + log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion) + if !m.shouldUpdate(updateVersion) { + return + } + + m.lastTrigger = time.Now() + log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion) + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "Automatically updating client", + "Your client version is older than auto-update version set in Management, updating client now.", + nil, + ) + + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "", + "", + map[string]string{"progress_window": "show", "version": updateVersion.String()}, + ) + + updateState := UpdateState{ + PreUpdateVersion: m.currentVersion, + TargetVersion: updateVersion.String(), + } + + if err := m.stateManager.UpdateState(updateState); err != nil { + log.Warnf("failed to update state: %v", err) + } else { + if err = m.stateManager.PersistState(ctx); err != nil { + log.Warnf("failed to persist state: %v", err) + } + } + + if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil { + log.Errorf("Error triggering auto-update: %v", err) + m.statusRecorder.PublishEvent( + cProto.SystemEvent_ERROR, + cProto.SystemEvent_SYSTEM, + "Auto-update failed", + fmt.Sprintf("Auto-update failed: %v", err), + nil, + ) + } +} + +// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it. +// Returns nil if no state exists. +func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) { + stateType := &UpdateState{} + + m.stateManager.RegisterState(stateType) + if err := m.stateManager.LoadState(stateType); err != nil { + return nil, fmt.Errorf("load state: %w", err) + } + + state := m.stateManager.GetState(stateType) + if state == nil { + return nil, errNoUpdateState + } + + updateState, ok := state.(*UpdateState) + if !ok { + return nil, fmt.Errorf("failed to cast state to UpdateState") + } + + if err := m.stateManager.DeleteState(updateState); err != nil { + return nil, fmt.Errorf("delete state: %w", err) + } + + if err := m.stateManager.PersistState(ctx); err != nil { + return nil, fmt.Errorf("persist state: %w", err) + } + + return updateState, nil +} + +func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { + if m.currentVersion == developmentVersion { + log.Debugf("skipping auto-update, running development version") + return false + } + currentVersion, err := v.NewVersion(m.currentVersion) + if err != nil { + log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err) + return false + } + if currentVersion.GreaterThanOrEqual(updateVersion) { + log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion) + return false + } + + if time.Since(m.lastTrigger) < 5*time.Minute { + log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) + return false + } + + return true +} + +func (m *Manager) lastResultErrReason() string { + inst := installer.New() + result := installer.NewResultHandler(inst.TempDir()) + return result.GetErrorResultReason() +} + +func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error { + inst := installer.New() + return inst.RunInstallation(ctx, targetVersion) +} diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go new file mode 100644 index 000000000..20ddec10d --- /dev/null +++ b/client/internal/updatemanager/manager_test.go @@ -0,0 +1,214 @@ +//go:build windows || darwin + +package updatemanager + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +type versionUpdateMock struct { + latestVersion *v.Version + onUpdate func() +} + +func (v versionUpdateMock) StopWatch() {} + +func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool { + return false +} + +func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) { + v.onUpdate = updateFn +} + +func (v versionUpdateMock) LatestVersion() *v.Version { + return v.latestVersion +} + +func (v versionUpdateMock) StartFetcher() {} + +func Test_LatestVersion(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should only trigger update once due to time between triggers being < 5 Minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: false, + }, + { + name: "Shouldn't update initially, but should update as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) + m.update = mockUpdate + + targetVersionChan := make(chan string, 1) + + m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetVersion("latest") + var triggeredInit bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.initialLatestVersion.String() { + t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion) + } + triggeredInit = true + case <-time.After(10 * time.Millisecond): + triggeredInit = false + } + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + var triggeredLater bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } + triggeredLater = true + case <-time.After(10 * time.Millisecond): + triggeredLater = false + } + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + + m.Stop() + } +} + +func Test_HandleUpdate(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + latestVersion *v.Version + expectedVersion string + shouldUpdate bool + }{ + { + name: "Update to a specific version should update regardless of if latestVersion is available yet", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.56.0", + shouldUpdate: true, + }, + { + name: "Update to specific version should not update if version matches", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.55.0", + shouldUpdate: false, + }, + { + name: "Update to specific version should not update if current version is newer", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.54.0", + shouldUpdate: false, + }, + { + name: "Update to latest version should update if latest is newer", + daemonVersion: "0.55.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: true, + }, + { + name: "Update to latest version should not update if latest == current", + daemonVersion: "0.56.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if daemon version is invalid", + daemonVersion: "development", + latestVersion: v.Must(v.NewSemver("1.0.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expecting latest and latest version is unavailable", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expected version is invalid", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "development", + shouldUpdate: false, + }, + } + for idx, c := range testMatrix { + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: c.latestVersion} + targetVersionChan := make(chan string, 1) + + m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetVersion(c.expectedVersion) + + var updateTriggered bool + select { + case targetVersion := <-targetVersionChan: + if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion) + } + updateTriggered = true + case <-time.After(10 * time.Millisecond): + updateTriggered = false + } + + if updateTriggered != c.shouldUpdate { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) + } + m.Stop() + } +} diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go new file mode 100644 index 000000000..4e87c2d77 --- /dev/null +++ b/client/internal/updatemanager/manager_unsupported.go @@ -0,0 +1,39 @@ +//go:build !windows && !darwin + +package updatemanager + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// Manager is a no-op stub for unsupported platforms +type Manager struct{} + +// NewManager returns a no-op manager for unsupported platforms +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + return nil, fmt.Errorf("update manager is not supported on this platform") +} + +// CheckUpdateSuccess is a no-op on unsupported platforms +func (m *Manager) CheckUpdateSuccess(ctx context.Context) { + // no-op +} + +// Start is a no-op on unsupported platforms +func (m *Manager) Start(ctx context.Context) { + // no-op +} + +// SetVersion is a no-op on unsupported platforms +func (m *Manager) SetVersion(expectedVersion string) { + // no-op +} + +// Stop is a no-op on unsupported platforms +func (m *Manager) Stop() { + // no-op +} diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updatemanager/reposign/artifact.go new file mode 100644 index 000000000..3d4fe9c74 --- /dev/null +++ b/client/internal/updatemanager/reposign/artifact.go @@ -0,0 +1,302 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "hash" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/blake2s" +) + +const ( + tagArtifactPrivate = "ARTIFACT PRIVATE KEY" + tagArtifactPublic = "ARTIFACT PUBLIC KEY" + + maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour + maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour +) + +// ArtifactHash wraps a hash.Hash and counts bytes written +type ArtifactHash struct { + hash.Hash +} + +// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s +func NewArtifactHash() *ArtifactHash { + h, err := blake2s.New256(nil) + if err != nil { + panic(err) // Should never happen with nil Key + } + return &ArtifactHash{Hash: h} +} + +func (ah *ArtifactHash) Write(b []byte) (int, error) { + return ah.Hash.Write(b) +} + +// ArtifactKey is a signing Key used to sign artifacts +type ArtifactKey struct { + PrivateKey +} + +func (k ArtifactKey) String() string { + return fmt.Sprintf( + "ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]", + k.Metadata.ID, + k.Metadata.CreatedAt.Format(time.RFC3339), + k.Metadata.ExpiresAt.Format(time.RFC3339), + ) +} + +func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) { + // Verify root key is still valid + if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) { + return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339)) + } + + now := time.Now() + expirationTime := now.Add(expiration) + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err) + } + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: now.UTC(), + ExpiresAt: expirationTime.UTC(), + } + + ak := &ArtifactKey{ + PrivateKey{ + Key: priv, + Metadata: metadata, + }, + } + + // Marshal PrivateKey struct to JSON + privJSON, err := json.Marshal(ak.PrivateKey) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + // Marshal PublicKey struct to JSON + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + pubJSON, err := json.Marshal(pubKey) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM with metadata embedded in bytes + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPrivate, + Bytes: privJSON, + }) + + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPublic, + Bytes: pubJSON, + }) + + // Sign the public key with the root key + signature, err := SignArtifactKey(*rootKey, pubPEM) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err) + } + + return ak, privPEM, pubPEM, signature, nil +} + +func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) { + pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate) + if err != nil { + return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err) + } + return ArtifactKey{pk}, nil +} + +func ParseArtifactPubKey(data []byte) (PublicKey, error) { + pk, _, err := parsePublicKey(data, tagArtifactPublic) + return pk, err +} + +func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) { + if len(keys) == 0 { + return nil, nil, errors.New("no keys to bundle") + } + + // Create bundle by concatenating PEM-encoded keys + var pubBundle []byte + + for _, pk := range keys { + // Marshal PublicKey struct to JSON + pubJSON, err := json.Marshal(pk) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPublic, + Bytes: pubJSON, + }) + + pubBundle = append(pubBundle, pubPEM...) + } + + // Sign the entire bundle with the root key + signature, err := SignArtifactKey(*rootKey, pubBundle) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err) + } + + return pubBundle, signature, nil +} + +func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) { + now := time.Now().UTC() + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("artifact signature error: %v", err) + return nil, err + } + if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge { + err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp) + log.Debugf("artifact signature error: %v", err) + return nil, err + } + + // Reconstruct the signed message: artifact_key_data || timestamp + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + if !verifyAny(publicRootKeys, msg, signature.Signature) { + return nil, errors.New("failed to verify signature of artifact keys") + } + + pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic) + if err != nil { + log.Debugf("failed to parse public keys: %s", err) + return nil, err + } + + validKeys := make([]PublicKey, 0, len(pubKeys)) + for _, pubKey := range pubKeys { + // Filter out expired keys + if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) { + log.Debugf("Key %s is expired at %v (current time %v)", + pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now) + continue + } + + if revocationList != nil { + if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked { + log.Debugf("Key %s is revoked as of %v (created %v)", + pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt) + continue + } + } + validKeys = append(validKeys, pubKey) + } + + if len(validKeys) == 0 { + log.Debugf("no valid public keys found for artifact keys") + return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys)) + } + + return validKeys, nil +} + +func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error { + // Validate signature timestamp + now := time.Now().UTC() + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("failed to verify signature of artifact: %s", err) + return err + } + if now.Sub(signature.Timestamp) > maxArtifactSignatureAge { + return fmt.Errorf("artifact signature is too old: %v (created %v)", + now.Sub(signature.Timestamp), signature.Timestamp) + } + + h := NewArtifactHash() + if _, err := h.Write(data); err != nil { + return fmt.Errorf("failed to hash artifact: %w", err) + } + hash := h.Sum(nil) + + // Reconstruct the signed message: hash || length || timestamp + msg := make([]byte, 0, len(hash)+8+8) + msg = append(msg, hash...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data))) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + // Find matching Key and verify + for _, keyInfo := range artifactPubKeys { + if keyInfo.Metadata.ID == signature.KeyID { + // Check Key expiration + if !keyInfo.Metadata.ExpiresAt.IsZero() && + signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) { + return fmt.Errorf("signing Key %s expired at %v, signature from %v", + signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp) + } + + if ed25519.Verify(keyInfo.Key, msg, signature.Signature) { + log.Debugf("artifact verified successfully with Key: %s", signature.KeyID) + return nil + } + return fmt.Errorf("signature verification failed for Key %s", signature.KeyID) + } + } + + return fmt.Errorf("no signing Key found with ID %s", signature.KeyID) +} + +func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) { + if len(data) == 0 { // Check happens too late + return nil, fmt.Errorf("artifact length must be positive, got %d", len(data)) + } + + h := NewArtifactHash() + if _, err := h.Write(data); err != nil { + return nil, fmt.Errorf("failed to write artifact hash: %w", err) + } + + timestamp := time.Now().UTC() + + if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) { + return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt) + } + + hash := h.Sum(nil) + + // Create message: hash || length || timestamp + msg := make([]byte, 0, len(hash)+8+8) + msg = append(msg, hash...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data))) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(artifactKey.Key, msg) + + bundle := Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: artifactKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + return json.Marshal(bundle) +} diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updatemanager/reposign/artifact_test.go new file mode 100644 index 000000000..8865e2d0a --- /dev/null +++ b/client/internal/updatemanager/reposign/artifact_test.go @@ -0,0 +1,1080 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test ArtifactHash + +func TestNewArtifactHash(t *testing.T) { + h := NewArtifactHash() + assert.NotNil(t, h) + assert.NotNil(t, h.Hash) +} + +func TestArtifactHash_Write(t *testing.T) { + h := NewArtifactHash() + + data := []byte("test data") + n, err := h.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) + + hash := h.Sum(nil) + assert.NotEmpty(t, hash) + assert.Equal(t, 32, len(hash)) // BLAKE2s-256 +} + +func TestArtifactHash_Deterministic(t *testing.T) { + data := []byte("test data") + + h1 := NewArtifactHash() + if _, err := h1.Write(data); err != nil { + t.Fatal(err) + } + hash1 := h1.Sum(nil) + + h2 := NewArtifactHash() + if _, err := h2.Write(data); err != nil { + t.Fatal(err) + } + hash2 := h2.Sum(nil) + + assert.Equal(t, hash1, hash2) +} + +func TestArtifactHash_DifferentData(t *testing.T) { + h1 := NewArtifactHash() + if _, err := h1.Write([]byte("data1")); err != nil { + t.Fatal(err) + } + hash1 := h1.Sum(nil) + + h2 := NewArtifactHash() + if _, err := h2.Write([]byte("data2")); err != nil { + t.Fatal(err) + } + hash2 := h2.Sum(nil) + + assert.NotEqual(t, hash1, hash2) +} + +// Test ArtifactKey.String() + +func TestArtifactKey_String(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + expiresAt := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + + ak := ArtifactKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: expiresAt, + }, + }, + } + + str := ak.String() + assert.Contains(t, str, "ArtifactKey") + assert.Contains(t, str, computeKeyID(pub).String()) + assert.Contains(t, str, "2024-01-15") + assert.Contains(t, str, "2025-01-15") +} + +// Test GenerateArtifactKey + +func TestGenerateArtifactKey_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + }, + }, + } + + // Generate artifact key + ak, privPEM, pubPEM, signature, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, ak) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + assert.NotEmpty(t, signature) + + // Verify expiration + assert.True(t, ak.Metadata.ExpiresAt.After(time.Now())) + assert.True(t, ak.Metadata.ExpiresAt.Before(time.Now().Add(31*24*time.Hour))) +} + +func TestGenerateArtifactKey_ExpiredRoot(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create expired root key + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().Add(-2 * 365 * 24 * time.Hour).UTC(), + ExpiresAt: time.Now().Add(-1 * time.Hour).UTC(), // Expired + }, + }, + } + + _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +func TestGenerateArtifactKey_NoExpiration(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Root key with no expiration + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Time{}, // No expiration + }, + }, + } + + ak, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, ak) +} + +// Test ParseArtifactKey + +func TestParseArtifactKey_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + original, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Parse it back + parsed, err := ParseArtifactKey(privPEM) + require.NoError(t, err) + + assert.Equal(t, original.Key, parsed.Key) + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) +} + +func TestParseArtifactKey_InvalidPEM(t *testing.T) { + _, err := ParseArtifactKey([]byte("invalid pem")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestParseArtifactKey_WrongType(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create a root key (wrong type) + rootKey := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + privJSON, err := json.Marshal(rootKey.PrivateKey) + require.NoError(t, err) + + privPEM := encodePrivateKey(privJSON, tagRootPrivate) + + _, err = ParseArtifactKey(privPEM) + assert.Error(t, err) +} + +// Test ParseArtifactPubKey + +func TestParseArtifactPubKey_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + original, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + parsed, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) +} + +func TestParseArtifactPubKey_Invalid(t *testing.T) { + _, err := ParseArtifactPubKey([]byte("invalid")) + assert.Error(t, err) +} + +// Test BundleArtifactKeys + +func TestBundleArtifactKeys_Single(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, signature, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + assert.NotEmpty(t, bundle) + assert.NotEmpty(t, signature) +} + +func TestBundleArtifactKeys_Multiple(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate 3 artifact keys + var pubKeys []PublicKey + for i := 0; i < 3; i++ { + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + pubKeys = append(pubKeys, pubKey) + } + + bundle, signature, err := BundleArtifactKeys(rootKey, pubKeys) + require.NoError(t, err) + assert.NotEmpty(t, bundle) + assert.NotEmpty(t, signature) + + // Verify we can parse the bundle + parsed, err := parsePublicKeyBundle(bundle, tagArtifactPublic) + require.NoError(t, err) + assert.Len(t, parsed, 3) +} + +func TestBundleArtifactKeys_Empty(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, err = BundleArtifactKeys(rootKey, []PublicKey{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no keys") +} + +// Test ValidateArtifactKeys + +func TestSingleValidateArtifactKey_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + _, _, pubPEM, sigData, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + sig, _ := ParseSignature(sigData) + + // Validate + validKeys, err := ValidateArtifactKeys(rootKeys, pubPEM, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) +} + +func TestValidateArtifactKeys_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + // Bundle and sign + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Validate + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) +} + +func TestValidateArtifactKeys_FutureTimestamp(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(10 * time.Minute), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateArtifactKeys_TooOld(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateArtifactKeys_InvalidSignature(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, _, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + // Create invalid signature + invalidSig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, bundle, invalidSig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to verify") +} + +func TestValidateArtifactKeys_WithRevocation(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate two artifact keys + _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey1, err := ParseArtifactPubKey(pubPEM1) + require.NoError(t, err) + + _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey2, err := ParseArtifactPubKey(pubPEM2) + require.NoError(t, err) + + // Bundle both keys + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Create revocation list with first key revoked + revocationList := &RevocationList{ + Revoked: map[KeyID]time.Time{ + pubKey1.Metadata.ID: time.Now().UTC(), + }, + LastUpdated: time.Now().UTC(), + } + + // Validate - should only return second key + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + assert.Equal(t, pubKey2.Metadata.ID, validKeys[0].Metadata.ID) +} + +func TestValidateArtifactKeys_AllRevoked(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Revoke the key + revocationList := &RevocationList{ + Revoked: map[KeyID]time.Time{ + pubKey.Metadata.ID: time.Now().UTC(), + }, + LastUpdated: time.Now().UTC(), + } + + _, err = ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList) + assert.Error(t, err) + assert.Contains(t, err.Error(), "revoked") +} + +// Test ValidateArtifact + +func TestValidateArtifact_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign some data + data := []byte("test artifact data") + sigData, err := SignData(*artifactKey, data) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Get public key for validation + artifactPubKey := PublicKey{ + Key: artifactKey.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey.Metadata, + } + + // Validate + err = ValidateArtifact([]PublicKey{artifactPubKey}, data, *sig) + assert.NoError(t, err) +} + +func TestValidateArtifact_FutureTimestamp(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(10 * time.Minute), + KeyID: computeKeyID(pub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateArtifact_TooOld(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour), + KeyID: computeKeyID(pub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateArtifact_ExpiredKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key with very short expiration + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + + // Wait for key to expire + time.Sleep(10 * time.Millisecond) + + // Try to sign - should succeed but with old timestamp + data := []byte("test data") + sigData, err := SignData(*artifactKey, data) + require.Error(t, err) // Key is expired, so signing should fail + assert.Contains(t, err.Error(), "expired") + assert.Nil(t, sigData) +} + +func TestValidateArtifact_WrongKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate two artifact keys + artifactKey1, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactKey2, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign with key1 + data := []byte("test data") + sigData, err := SignData(*artifactKey1, data) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Try to validate with key2 only + artifactPubKey2 := PublicKey{ + Key: artifactKey2.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey2.Metadata, + } + + err = ValidateArtifact([]PublicKey{artifactPubKey2}, data, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no signing Key found") +} + +func TestValidateArtifact_TamperedData(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign original data + originalData := []byte("original data") + sigData, err := SignData(*artifactKey, originalData) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: artifactKey.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey.Metadata, + } + + // Try to validate with tampered data + tamperedData := []byte("tampered data") + err = ValidateArtifact([]PublicKey{artifactPubKey}, tamperedData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "verification failed") +} + +func TestValidateArtifactKeys_TwoKeysOneExpired(t *testing.T) { + // Test ValidateArtifactKeys with a bundle containing two keys where one is expired + // Should return only the valid key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate first key with very short expiration + _, _, expiredPubPEM, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + expiredPubKey, err := ParseArtifactPubKey(expiredPubPEM) + require.NoError(t, err) + + // Wait for first key to expire + time.Sleep(10 * time.Millisecond) + + // Generate second key with normal expiration + _, _, validPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + validPubKey, err := ParseArtifactPubKey(validPubPEM) + require.NoError(t, err) + + // Bundle both keys together + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{expiredPubKey, validPubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // ValidateArtifactKeys should return only the valid key + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + assert.Equal(t, validPubKey.Metadata.ID, validKeys[0].Metadata.ID) +} + +func TestValidateArtifactKeys_TwoKeysBothExpired(t *testing.T) { + // Test ValidateArtifactKeys with a bundle containing two expired keys + // Should fail because no valid keys remain + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate first key with + _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 24*time.Hour) + require.NoError(t, err) + pubKey1, err := ParseArtifactPubKey(pubPEM1) + require.NoError(t, err) + + // Generate second key with very short expiration + _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + pubKey2, err := ParseArtifactPubKey(pubPEM2) + require.NoError(t, err) + + // Wait for expire + time.Sleep(10 * time.Millisecond) + + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // ValidateArtifactKeys should fail because all keys are expired + keys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + assert.NoError(t, err) + assert.Len(t, keys, 1) +} + +// Test SignData + +func TestSignData_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + data := []byte("test data to sign") + sigData, err := SignData(*artifactKey, data) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} + +func TestSignData_EmptyData(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + _, err = SignData(*artifactKey, []byte{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") +} + +func TestSignData_ExpiredKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate key with very short expiration + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + + // Wait for expiration + time.Sleep(10 * time.Millisecond) + + // Try to sign with expired key + _, err = SignData(*artifactKey, []byte("data")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +// Integration test + +func TestArtifact_FullWorkflow(t *testing.T) { + // Step 1: Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Step 2: Generate artifact key + artifactKey, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Step 3: Create and validate key bundle + artifactPubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, bundleSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(bundleSig) + require.NoError(t, err) + + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + + // Step 4: Sign artifact data + artifactData := []byte("This is my artifact data that needs to be signed") + artifactSig, err := SignData(*artifactKey, artifactData) + require.NoError(t, err) + + // Step 5: Validate artifact + parsedSig, err := ParseSignature(artifactSig) + require.NoError(t, err) + + err = ValidateArtifact(validKeys, artifactData, *parsedSig) + assert.NoError(t, err) +} + +// Helper function for tests +func encodePrivateKey(jsonData []byte, typeTag string) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: typeTag, + Bytes: jsonData, + }) +} diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updatemanager/reposign/certs/root-pub.pem new file mode 100644 index 000000000..e7c2fd2c0 --- /dev/null +++ b/client/internal/updatemanager/reposign/certs/root-pub.pem @@ -0,0 +1,6 @@ +-----BEGIN ROOT PUBLIC KEY----- +eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn +V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0 +ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz +X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19 +-----END ROOT PUBLIC KEY----- diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updatemanager/reposign/certsdev/root-pub.pem new file mode 100644 index 000000000..f7145477b --- /dev/null +++ b/client/internal/updatemanager/reposign/certsdev/root-pub.pem @@ -0,0 +1,6 @@ +-----BEGIN ROOT PUBLIC KEY----- +eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr +ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0 +ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz +X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19 +-----END ROOT PUBLIC KEY----- diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updatemanager/reposign/doc.go new file mode 100644 index 000000000..660b9d11d --- /dev/null +++ b/client/internal/updatemanager/reposign/doc.go @@ -0,0 +1,174 @@ +// Package reposign implements a cryptographic signing and verification system +// for NetBird software update artifacts. It provides a hierarchical key +// management system with support for key rotation, revocation, and secure +// artifact distribution. +// +// # Architecture +// +// The package uses a two-tier key hierarchy: +// +// - Root Keys: Long-lived keys that sign artifact keys. These are embedded +// in the client binary and establish the root of trust. Root keys should +// be kept offline and highly secured. +// +// - Artifact Keys: Short-lived keys that sign release artifacts (binaries, +// packages, etc.). These are rotated regularly and can be revoked if +// compromised. Artifact keys are signed by root keys and distributed via +// a public repository. +// +// This separation allows for operational flexibility: artifact keys can be +// rotated frequently without requiring client updates, while root keys remain +// stable and embedded in the software. +// +// # Cryptographic Primitives +// +// The package uses strong, modern cryptographic algorithms: +// - Ed25519: Fast, secure digital signatures (no timing attacks) +// - BLAKE2s-256: Fast cryptographic hash for artifacts +// - SHA-256: Key ID generation +// - JSON: Structured key and signature serialization +// - PEM: Standard key encoding format +// +// # Security Features +// +// Timestamp Binding: +// - All signatures include cryptographically-bound timestamps +// - Prevents replay attacks and enforces signature freshness +// - Clock skew tolerance: 5 minutes +// +// Key Expiration: +// - All keys have expiration times +// - Expired keys are automatically rejected +// - Signing with an expired key fails immediately +// +// Key Revocation: +// - Compromised keys can be revoked via a signed revocation list +// - Revocation list is checked during artifact validation +// - Revoked keys are filtered out before artifact verification +// +// # File Structure +// +// The package expects the following file layout in the key repository: +// +// signrepo/ +// artifact-key-pub.pem # Bundle of artifact public keys +// artifact-key-pub.pem.sig # Root signature of the bundle +// revocation-list.json # List of revoked key IDs +// revocation-list.json.sig # Root signature of revocation list +// +// And in the artifacts repository: +// +// releases/ +// v0.28.0/ +// netbird-linux-amd64 +// netbird-linux-amd64.sig # Artifact signature +// netbird-darwin-amd64 +// netbird-darwin-amd64.sig +// ... +// +// # Embedded Root Keys +// +// Root public keys are embedded in the client binary at compile time: +// - Production keys: certs/ directory +// - Development keys: certsdev/ directory +// +// The build tag determines which keys are embedded: +// - Production builds: //go:build !devartifactsign +// - Development builds: //go:build devartifactsign +// +// This ensures that development artifacts cannot be verified using production +// keys and vice versa. +// +// # Key Rotation Strategies +// +// Root Key Rotation: +// +// Root keys can be rotated without breaking existing clients by leveraging +// the multi-key verification system. The loadEmbeddedPublicKeys function +// reads ALL files from the certs/ directory and accepts signatures from ANY +// of the embedded root keys. +// +// To rotate root keys: +// +// 1. Generate a new root key pair: +// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour) +// +// 2. Add the new public key to the certs/ directory as a new file: +// certs/ +// root-pub-2024.pem # Old key (keep this!) +// root-pub-2025.pem # New key (add this) +// +// 3. Build new client versions with both keys embedded. The verification +// will accept signatures from either key. +// +// 4. Start signing new artifact keys with the new root key. Old clients +// with only the old root key will reject these, but new clients with +// both keys will accept them. +// +// Each file in certs/ can contain a single key or a bundle of keys (multiple +// PEM blocks). The system will parse all keys from all files and use them +// for verification. This provides maximum flexibility for key management. +// +// Important: Never remove all old root keys at once. Always maintain at least +// one overlapping key between releases to ensure smooth transitions. +// +// Artifact Key Rotation: +// +// Artifact keys should be rotated regularly (e.g., every 90 days) using the +// bundling mechanism. The BundleArtifactKeys function allows multiple artifact +// keys to be bundled together in a single signed package, and ValidateArtifact +// will accept signatures from ANY key in the bundle. +// +// To rotate artifact keys smoothly: +// +// 1. Generate a new artifact key while keeping the old one: +// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour) +// // Keep oldPubPEM and oldKey available +// +// 2. Create a bundle containing both old and new public keys +// +// 3. Upload the bundle and its signature to the key repository: +// signrepo/artifact-key-pub.pem # Contains both keys +// signrepo/artifact-key-pub.pem.sig # Root signature +// +// 4. Start signing new releases with the NEW key, but keep the bundle +// unchanged. Clients will download the bundle (containing both keys) +// and accept signatures from either key. +// +// Key bundle validation workflow: +// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig +// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key +// 3. ValidateArtifactKeys parses all public keys from the bundle +// 4. ValidateArtifactKeys filters out expired or revoked keys +// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds +// +// This multi-key acceptance model enables overlapping validity periods and +// smooth transitions without client update requirements. +// +// # Best Practices +// +// Root Key Management: +// - Generate root keys offline on an air-gapped machine +// - Store root private keys in hardware security modules (HSM) if possible +// - Use separate root keys for production and development +// - Rotate root keys infrequently (e.g., every 5-10 years) +// - Plan for root key rotation: embed multiple root public keys +// +// Artifact Key Management: +// - Rotate artifact keys regularly (e.g., every 90 days) +// - Use separate artifact keys for different release channels if needed +// - Revoke keys immediately upon suspected compromise +// - Bundle multiple artifact keys to enable smooth rotation +// +// Signing Process: +// - Sign artifacts in a secure CI/CD environment +// - Never commit private keys to version control +// - Use environment variables or secret management for keys +// - Verify signatures immediately after signing +// +// Distribution: +// - Serve keys and revocation lists from a reliable CDN +// - Use HTTPS for all key and artifact downloads +// - Monitor download failures and signature verification failures +// - Keep revocation list up to date +package reposign diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updatemanager/reposign/embed_dev.go new file mode 100644 index 000000000..ef8f77373 --- /dev/null +++ b/client/internal/updatemanager/reposign/embed_dev.go @@ -0,0 +1,10 @@ +//go:build devartifactsign + +package reposign + +import "embed" + +//go:embed certsdev +var embeddedCerts embed.FS + +const embeddedCertsDir = "certsdev" diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updatemanager/reposign/embed_prod.go new file mode 100644 index 000000000..91530e5f4 --- /dev/null +++ b/client/internal/updatemanager/reposign/embed_prod.go @@ -0,0 +1,10 @@ +//go:build !devartifactsign + +package reposign + +import "embed" + +//go:embed certs +var embeddedCerts embed.FS + +const embeddedCertsDir = "certs" diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updatemanager/reposign/key.go new file mode 100644 index 000000000..bedfef70d --- /dev/null +++ b/client/internal/updatemanager/reposign/key.go @@ -0,0 +1,171 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "time" +) + +const ( + maxClockSkew = 5 * time.Minute +) + +// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key) +type KeyID [8]byte + +// computeKeyID generates a unique ID from a public Key +func computeKeyID(pub ed25519.PublicKey) KeyID { + h := sha256.Sum256(pub) + var id KeyID + copy(id[:], h[:8]) + return id +} + +// MarshalJSON implements json.Marshaler for KeyID +func (k KeyID) MarshalJSON() ([]byte, error) { + return json.Marshal(k.String()) +} + +// UnmarshalJSON implements json.Unmarshaler for KeyID +func (k *KeyID) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + parsed, err := ParseKeyID(s) + if err != nil { + return err + } + + *k = parsed + return nil +} + +// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID. +func ParseKeyID(s string) (KeyID, error) { + var id KeyID + if len(s) != 16 { + return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s)) + } + + b, err := hex.DecodeString(s) + if err != nil { + return id, fmt.Errorf("failed to decode KeyID: %w", err) + } + + copy(id[:], b) + return id, nil +} + +func (k KeyID) String() string { + return fmt.Sprintf("%x", k[:]) +} + +// KeyMetadata contains versioning and lifecycle information for a Key +type KeyMetadata struct { + ID KeyID `json:"id"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration +} + +// PublicKey wraps a public Key with its Metadata +type PublicKey struct { + Key ed25519.PublicKey + Metadata KeyMetadata +} + +func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) { + var keys []PublicKey + for len(bundle) > 0 { + keyInfo, rest, err := parsePublicKey(bundle, typeTag) + if err != nil { + return nil, err + } + keys = append(keys, keyInfo) + bundle = rest + } + if len(keys) == 0 { + return nil, errors.New("no keys found in bundle") + } + return keys, nil +} + +func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) { + b, rest := pem.Decode(data) + if b == nil { + return PublicKey{}, nil, errors.New("failed to decode PEM data") + } + if b.Type != typeTag { + return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + + // Unmarshal JSON-embedded format + var pub PublicKey + if err := json.Unmarshal(b.Bytes, &pub); err != nil { + return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err) + } + + // Validate key length + if len(pub.Key) != ed25519.PublicKeySize { + return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d", + ed25519.PublicKeySize, len(pub.Key)) + } + + // Always recompute ID to ensure integrity + pub.Metadata.ID = computeKeyID(pub.Key) + + return pub, rest, nil +} + +type PrivateKey struct { + Key ed25519.PrivateKey + Metadata KeyMetadata +} + +func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) { + b, rest := pem.Decode(data) + if b == nil { + return PrivateKey{}, errors.New("failed to decode PEM data") + } + if len(rest) > 0 { + return PrivateKey{}, errors.New("trailing PEM data") + } + if b.Type != typeTag { + return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + + // Unmarshal JSON-embedded format + var pk PrivateKey + if err := json.Unmarshal(b.Bytes, &pk); err != nil { + return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err) + } + + // Validate key length + if len(pk.Key) != ed25519.PrivateKeySize { + return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d", + ed25519.PrivateKeySize, len(pk.Key)) + } + + return pk, nil +} + +func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool { + // Verify with root keys + var rootKeys []ed25519.PublicKey + for _, r := range publicRootKeys { + rootKeys = append(rootKeys, r.Key) + } + + for _, k := range rootKeys { + if ed25519.Verify(k, msg, sig) { + return true + } + } + return false +} diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updatemanager/reposign/key_test.go new file mode 100644 index 000000000..f8e1676fb --- /dev/null +++ b/client/internal/updatemanager/reposign/key_test.go @@ -0,0 +1,636 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test KeyID functions + +func TestComputeKeyID(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + + // Verify it's the first 8 bytes of SHA-256 + h := sha256.Sum256(pub) + expectedID := KeyID{} + copy(expectedID[:], h[:8]) + + assert.Equal(t, expectedID, keyID) +} + +func TestComputeKeyID_Deterministic(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Computing KeyID multiple times should give the same result + keyID1 := computeKeyID(pub) + keyID2 := computeKeyID(pub) + + assert.Equal(t, keyID1, keyID2) +} + +func TestComputeKeyID_DifferentKeys(t *testing.T) { + pub1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID1 := computeKeyID(pub1) + keyID2 := computeKeyID(pub2) + + // Different keys should produce different IDs + assert.NotEqual(t, keyID1, keyID2) +} + +func TestParseKeyID_Valid(t *testing.T) { + hexStr := "0123456789abcdef" + + keyID, err := ParseKeyID(hexStr) + require.NoError(t, err) + + expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + assert.Equal(t, expected, keyID) +} + +func TestParseKeyID_InvalidLength(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"too short", "01234567"}, + {"too long", "0123456789abcdef00"}, + {"empty", ""}, + {"odd length", "0123456789abcde"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseKeyID(tt.input) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid KeyID length") + }) + } +} + +func TestParseKeyID_InvalidHex(t *testing.T) { + invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex + + _, err := ParseKeyID(invalidHex) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode KeyID") +} + +func TestKeyID_String(t *testing.T) { + keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + + str := keyID.String() + assert.Equal(t, "0123456789abcdef", str) +} + +func TestKeyID_RoundTrip(t *testing.T) { + original := "fedcba9876543210" + + keyID, err := ParseKeyID(original) + require.NoError(t, err) + + result := keyID.String() + assert.Equal(t, original, result) +} + +func TestKeyID_ZeroValue(t *testing.T) { + keyID := KeyID{} + str := keyID.String() + assert.Equal(t, "0000000000000000", str) +} + +// Test KeyMetadata + +func TestKeyMetadata_JSONMarshaling(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC), + } + + jsonData, err := json.Marshal(metadata) + require.NoError(t, err) + + var decoded KeyMetadata + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, metadata.ID, decoded.ID) + assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix()) + assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix()) +} + +func TestKeyMetadata_NoExpiration(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + ExpiresAt: time.Time{}, // Zero value = no expiration + } + + jsonData, err := json.Marshal(metadata) + require.NoError(t, err) + + var decoded KeyMetadata + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.True(t, decoded.ExpiresAt.IsZero()) +} + +// Test PublicKey + +func TestPublicKey_JSONMarshaling(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + var decoded PublicKey + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, pubKey.Key, decoded.Key) + assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID) +} + +// Test parsePublicKey + +func TestParsePublicKey_Valid(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + } + + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + + // Marshal to JSON + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + // Encode to PEM + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + // Parse it back + parsed, rest, err := parsePublicKey(pemData, tagRootPublic) + require.NoError(t, err) + assert.Empty(t, rest) + assert.Equal(t, pub, parsed.Key) + assert.Equal(t, metadata.ID, parsed.Metadata.ID) +} + +func TestParsePublicKey_InvalidPEM(t *testing.T) { + invalidPEM := []byte("not a PEM") + + _, _, err := parsePublicKey(invalidPEM, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM") +} + +func TestParsePublicKey_WrongType(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + // Encode with wrong type + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "WRONG TYPE", + Bytes: jsonData, + }) + + _, _, err = parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") +} + +func TestParsePublicKey_InvalidJSON(t *testing.T) { + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: []byte("invalid json"), + }) + + _, _, err := parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal") +} + +func TestParsePublicKey_InvalidKeySize(t *testing.T) { + // Create a public key with wrong size + pubKey := PublicKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + _, _, err = parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 public key size") +} + +func TestParsePublicKey_IDRecomputation(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create a public key with WRONG ID + wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: wrongID, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + // Parse should recompute the correct ID + parsed, _, err := parsePublicKey(pemData, tagRootPublic) + require.NoError(t, err) + + correctID := computeKeyID(pub) + assert.Equal(t, correctID, parsed.Metadata.ID) + assert.NotEqual(t, wrongID, parsed.Metadata.ID) +} + +// Test parsePublicKeyBundle + +func TestParsePublicKeyBundle_Single(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + keys, err := parsePublicKeyBundle(pemData, tagRootPublic) + require.NoError(t, err) + assert.Len(t, keys, 1) + assert.Equal(t, pub, keys[0].Key) +} + +func TestParsePublicKeyBundle_Multiple(t *testing.T) { + var bundle []byte + + // Create 3 keys + for i := 0; i < 3; i++ { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + bundle = append(bundle, pemData...) + } + + keys, err := parsePublicKeyBundle(bundle, tagRootPublic) + require.NoError(t, err) + assert.Len(t, keys, 3) +} + +func TestParsePublicKeyBundle_Empty(t *testing.T) { + _, err := parsePublicKeyBundle([]byte{}, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no keys found") +} + +func TestParsePublicKeyBundle_Invalid(t *testing.T) { + _, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic) + assert.Error(t, err) +} + +// Test PrivateKey + +func TestPrivateKey_JSONMarshaling(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + var decoded PrivateKey + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, privKey.Key, decoded.Key) + assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID) +} + +// Test parsePrivateKey + +func TestParsePrivateKey_Valid(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + parsed, err := parsePrivateKey(pemData, tagRootPrivate) + require.NoError(t, err) + assert.Equal(t, priv, parsed.Key) +} + +func TestParsePrivateKey_InvalidPEM(t *testing.T) { + _, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM") +} + +func TestParsePrivateKey_TrailingData(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + // Add trailing data + pemData = append(pemData, []byte("extra data")...) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "trailing PEM data") +} + +func TestParsePrivateKey_WrongType(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "WRONG TYPE", + Bytes: jsonData, + }) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") +} + +func TestParsePrivateKey_InvalidKeySize(t *testing.T) { + privKey := PrivateKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 private key size") +} + +// Test verifyAny + +func TestVerifyAny_ValidSignature(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv, message) + + rootKeys := []PublicKey{ + { + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + result := verifyAny(rootKeys, message, signature) + assert.True(t, result) +} + +func TestVerifyAny_InvalidSignature(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + invalidSignature := make([]byte, ed25519.SignatureSize) + + rootKeys := []PublicKey{ + { + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + result := verifyAny(rootKeys, message, invalidSignature) + assert.False(t, result) +} + +func TestVerifyAny_MultipleKeys(t *testing.T) { + // Create 3 key pairs + pub1, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub3, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv1, message) + + rootKeys := []PublicKey{ + {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}}, + {Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle + {Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}}, + } + + result := verifyAny(rootKeys, message, signature) + assert.True(t, result) +} + +func TestVerifyAny_NoMatchingKey(t *testing.T) { + _, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv1, message) + + // Only include pub2, not pub1 + rootKeys := []PublicKey{ + {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}}, + } + + result := verifyAny(rootKeys, message, signature) + assert.False(t, result) +} + +func TestVerifyAny_EmptyKeys(t *testing.T) { + message := []byte("test message") + signature := make([]byte, ed25519.SignatureSize) + + result := verifyAny([]PublicKey{}, message, signature) + assert.False(t, result) +} + +func TestVerifyAny_TamperedMessage(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv, message) + + rootKeys := []PublicKey{ + {Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}}, + } + + // Verify with different message + tamperedMessage := []byte("different message") + result := verifyAny(rootKeys, tamperedMessage, signature) + assert.False(t, result) +} diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updatemanager/reposign/revocation.go new file mode 100644 index 000000000..e679e212f --- /dev/null +++ b/client/internal/updatemanager/reposign/revocation.go @@ -0,0 +1,229 @@ +package reposign + +import ( + "crypto/ed25519" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour + defaultRevocationListExpiration = 365 * 24 * time.Hour +) + +type RevocationList struct { + Revoked map[KeyID]time.Time `json:"revoked"` // KeyID -> revocation time + LastUpdated time.Time `json:"last_updated"` // When the list was last modified + ExpiresAt time.Time `json:"expires_at"` // When the list expires +} + +func (rl RevocationList) MarshalJSON() ([]byte, error) { + // Convert map[KeyID]time.Time to map[string]time.Time + strMap := make(map[string]time.Time, len(rl.Revoked)) + for k, v := range rl.Revoked { + strMap[k.String()] = v + } + + return json.Marshal(map[string]interface{}{ + "revoked": strMap, + "last_updated": rl.LastUpdated, + "expires_at": rl.ExpiresAt, + }) +} + +func (rl *RevocationList) UnmarshalJSON(data []byte) error { + var temp struct { + Revoked map[string]time.Time `json:"revoked"` + LastUpdated time.Time `json:"last_updated"` + ExpiresAt time.Time `json:"expires_at"` + Version int `json:"version"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Convert map[string]time.Time back to map[KeyID]time.Time + rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked)) + for k, v := range temp.Revoked { + kid, err := ParseKeyID(k) + if err != nil { + return fmt.Errorf("failed to parse KeyID %q: %w", k, err) + } + rl.Revoked[kid] = v + } + + rl.LastUpdated = temp.LastUpdated + rl.ExpiresAt = temp.ExpiresAt + + return nil +} + +func ParseRevocationList(data []byte) (*RevocationList, error) { + var rl RevocationList + if err := json.Unmarshal(data, &rl); err != nil { + return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err) + } + + // Initialize the map if it's nil (in case of empty JSON object) + if rl.Revoked == nil { + rl.Revoked = make(map[KeyID]time.Time) + } + + if rl.LastUpdated.IsZero() { + return nil, fmt.Errorf("revocation list missing last_updated timestamp") + } + + if rl.ExpiresAt.IsZero() { + return nil, fmt.Errorf("revocation list missing expires_at timestamp") + } + + return &rl, nil +} + +func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) { + revoList, err := ParseRevocationList(data) + if err != nil { + log.Debugf("failed to parse revocation list: %s", err) + return nil, err + } + + now := time.Now().UTC() + + // Validate signature timestamp + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("revocation list signature error: %v", err) + return nil, err + } + + if now.Sub(signature.Timestamp) > maxRevocationSignatureAge { + err := fmt.Errorf("revocation list signature is too old: %v (created %v)", + now.Sub(signature.Timestamp), signature.Timestamp) + log.Debugf("revocation list signature error: %v", err) + return nil, err + } + + // Ensure LastUpdated is not in the future (with clock skew tolerance) + if revoList.LastUpdated.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated) + log.Errorf("rejecting future-dated revocation list: %v", err) + return nil, err + } + + // Check if the revocation list has expired + if now.After(revoList.ExpiresAt) { + err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now) + log.Errorf("rejecting expired revocation list: %v", err) + return nil, err + } + + // Ensure ExpiresAt is not in the future by more than the expected expiration window + // (allows some clock skew but prevents maliciously long expiration times) + if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) { + err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt) + log.Errorf("rejecting revocation list with invalid expiration: %v", err) + return nil, err + } + + // Validate signature timestamp is close to LastUpdated + // (prevents signing old lists with new timestamps) + timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs() + if timeDiff > maxClockSkew { + err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)", + signature.Timestamp, revoList.LastUpdated, timeDiff) + log.Errorf("timestamp mismatch in revocation list: %v", err) + return nil, err + } + + // Reconstruct the signed message: revocation_list_data || timestamp || version + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + if !verifyAny(publicRootKeys, msg, signature.Signature) { + return nil, errors.New("revocation list verification failed") + } + return revoList, nil +} + +func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) { + now := time.Now() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now.UTC(), + ExpiresAt: now.Add(expiration).UTC(), + } + + signature, err := signRevocationList(privateRootKey, rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err) + } + + rlData, err := json.Marshal(&rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err) + } + + signData, err := json.Marshal(signature) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal signature: %w", err) + } + + return rlData, signData, nil +} + +func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) { + now := time.Now().UTC() + + rl.Revoked[kid] = now + rl.LastUpdated = now + rl.ExpiresAt = now.Add(expiration) + + signature, err := signRevocationList(privateRootKey, rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err) + } + + rlData, err := json.Marshal(&rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err) + } + + signData, err := json.Marshal(signature) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal signature: %w", err) + } + + return rlData, signData, nil +} + +func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) { + data, err := json.Marshal(rl) + if err != nil { + return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err) + } + + timestamp := time.Now().UTC() + + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(privateRootKey.Key, msg) + + signature := &Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: privateRootKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + return signature, nil +} diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updatemanager/reposign/revocation_test.go new file mode 100644 index 000000000..d6d748f3d --- /dev/null +++ b/client/internal/updatemanager/reposign/revocation_test.go @@ -0,0 +1,860 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test RevocationList marshaling/unmarshaling + +func TestRevocationList_MarshalJSON(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC) + + rl := &RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID: revokedTime, + }, + LastUpdated: lastUpdated, + ExpiresAt: expiresAt, + } + + jsonData, err := json.Marshal(rl) + require.NoError(t, err) + + // Verify it can be unmarshaled back + var decoded map[string]interface{} + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Contains(t, decoded, "revoked") + assert.Contains(t, decoded, "last_updated") + assert.Contains(t, decoded, "expires_at") +} + +func TestRevocationList_UnmarshalJSON(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + + jsonData := map[string]interface{}{ + "revoked": map[string]string{ + keyID.String(): revokedTime.Format(time.RFC3339), + }, + "last_updated": lastUpdated.Format(time.RFC3339), + } + + jsonBytes, err := json.Marshal(jsonData) + require.NoError(t, err) + + var rl RevocationList + err = json.Unmarshal(jsonBytes, &rl) + require.NoError(t, err) + + assert.Len(t, rl.Revoked, 1) + assert.Contains(t, rl.Revoked, keyID) + assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix()) +} + +func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) { + pub1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID1 := computeKeyID(pub1) + keyID2 := computeKeyID(pub2) + + original := &RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC), + }, + LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC), + } + + // Marshal + jsonData, err := original.MarshalJSON() + require.NoError(t, err) + + // Unmarshal + var decoded RevocationList + err = decoded.UnmarshalJSON(jsonData) + require.NoError(t, err) + + // Verify + assert.Len(t, decoded.Revoked, 2) + assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix()) + assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix()) + assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix()) +} + +func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) { + jsonData := []byte(`{ + "revoked": { + "invalid_key_id": "2024-01-15T10:30:00Z" + }, + "last_updated": "2024-01-15T11:00:00Z" + }`) + + var rl RevocationList + err := json.Unmarshal(jsonData, &rl) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse KeyID") +} + +func TestRevocationList_EmptyRevoked(t *testing.T) { + rl := &RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC(), + } + + jsonData, err := rl.MarshalJSON() + require.NoError(t, err) + + var decoded RevocationList + err = decoded.UnmarshalJSON(jsonData) + require.NoError(t, err) + + assert.Empty(t, decoded.Revoked) + assert.NotNil(t, decoded.Revoked) +} + +// Test ParseRevocationList + +func TestParseRevocationList_Valid(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + + rl := RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID: revokedTime, + }, + LastUpdated: lastUpdated, + ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC), + } + + jsonData, err := rl.MarshalJSON() + require.NoError(t, err) + + parsed, err := ParseRevocationList(jsonData) + require.NoError(t, err) + assert.NotNil(t, parsed) + assert.Len(t, parsed.Revoked, 1) + assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix()) +} + +func TestParseRevocationList_InvalidJSON(t *testing.T) { + invalidJSON := []byte("not valid json") + + _, err := ParseRevocationList(invalidJSON) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal") +} + +func TestParseRevocationList_MissingLastUpdated(t *testing.T) { + jsonData := []byte(`{ + "revoked": {} + }`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing last_updated") +} + +func TestParseRevocationList_EmptyObject(t *testing.T) { + jsonData := []byte(`{}`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing last_updated") +} + +func TestParseRevocationList_NilRevoked(t *testing.T) { + lastUpdated := time.Now().UTC() + expiresAt := lastUpdated.Add(90 * 24 * time.Hour) + jsonData := []byte(`{ + "last_updated": "` + lastUpdated.Format(time.RFC3339) + `", + "expires_at": "` + expiresAt.Format(time.RFC3339) + `" + }`) + + parsed, err := ParseRevocationList(jsonData) + require.NoError(t, err) + assert.NotNil(t, parsed.Revoked) + assert.Empty(t, parsed.Revoked) +} + +func TestParseRevocationList_MissingExpiresAt(t *testing.T) { + lastUpdated := time.Now().UTC() + jsonData := []byte(`{ + "revoked": {}, + "last_updated": "` + lastUpdated.Format(time.RFC3339) + `" + }`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing expires_at") +} + +// Test ValidateRevocationList + +func TestValidateRevocationList_Valid(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Validate + rl, err := ValidateRevocationList(rootKeys, rlData, *signature) + require.NoError(t, err) + assert.NotNil(t, rl) + assert.Empty(t, rl.Revoked) +} + +func TestValidateRevocationList_InvalidSignature(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Create invalid signature + invalidSig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + // Validate should fail + _, err = ValidateRevocationList(rootKeys, rlData, invalidSig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "verification failed") +} + +func TestValidateRevocationList_FutureTimestamp(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Modify timestamp to be in the future + signature.Timestamp = time.Now().UTC().Add(10 * time.Minute) + + _, err = ValidateRevocationList(rootKeys, rlData, *signature) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateRevocationList_TooOld(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Modify timestamp to be too old + signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour) + + _, err = ValidateRevocationList(rootKeys, rlData, *signature) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateRevocationList_InvalidJSON(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + signature := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + _, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature) + assert.Error(t, err) +} + +func TestValidateRevocationList_FutureLastUpdated(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with future LastUpdated + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC().Add(10 * time.Minute), + ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour), + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "LastUpdated is in the future") +} + +func TestValidateRevocationList_TimestampMismatch(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with LastUpdated far in the past + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC().Add(-1 * time.Hour), + ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour), + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it with current timestamp + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + // Modify signature timestamp to differ too much from LastUpdated + sig.Timestamp = time.Now().UTC() + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "differs too much") +} + +func TestValidateRevocationList_Expired(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list that expired in the past + now := time.Now().UTC() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now.Add(-100 * 24 * time.Hour), + ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + // Adjust signature timestamp to match LastUpdated + sig.Timestamp = rl.LastUpdated + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge) + now := time.Now().UTC() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now, + ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too far in the future") +} + +// Test CreateRevocationList + +func TestCreateRevocationList_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + assert.NotEmpty(t, rlData) + assert.NotEmpty(t, sigData) + + // Verify it can be parsed + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + assert.False(t, rl.LastUpdated.IsZero()) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +// Test ExtendRevocationList + +func TestExtendRevocationList_AddKey(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + + // Generate a key to revoke + revokedPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + revokedKeyID := computeKeyID(revokedPub) + + // Extend the revocation list + newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Verify the new list + newRL, err := ParseRevocationList(newRLData) + require.NoError(t, err) + assert.Len(t, newRL.Revoked, 1) + assert.Contains(t, newRL.Revoked, revokedKeyID) + + // Verify signature + sig, err := ParseSignature(newSigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestExtendRevocationList_MultipleKeys(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + + // Add first key + key1Pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + key1ID := computeKeyID(key1Pub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + + // Add second key + key2Pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + key2ID := computeKeyID(key2Pub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 2) + assert.Contains(t, rl.Revoked, key1ID) + assert.Contains(t, rl.Revoked, key2ID) +} + +func TestExtendRevocationList_DuplicateKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + + // Add a key + keyPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := computeKeyID(keyPub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + firstRevocationTime := rl.Revoked[keyID] + + // Wait a bit + time.Sleep(10 * time.Millisecond) + + // Add the same key again + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + + // The revocation time should be updated + secondRevocationTime := rl.Revoked[keyID] + assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime)) +} + +func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + firstLastUpdated := rl.LastUpdated + + // Wait a bit + time.Sleep(10 * time.Millisecond) + + // Extend list + keyPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := computeKeyID(keyPub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + + // LastUpdated should be updated + assert.True(t, rl.LastUpdated.After(firstLastUpdated)) +} + +// Integration test + +func TestRevocationList_FullWorkflow(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Step 1: Create empty revocation list + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 2: Validate it + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + rl, err := ValidateRevocationList(rootKeys, rlData, *sig) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + + // Step 3: Revoke a key + revokedPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + revokedKeyID := computeKeyID(revokedPub) + + rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 4: Validate the extended list + sig, err = ParseSignature(sigData) + require.NoError(t, err) + + rl, err = ValidateRevocationList(rootKeys, rlData, *sig) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + assert.Contains(t, rl.Revoked, revokedKeyID) + + // Step 5: Verify the revocation time is reasonable + revTime := rl.Revoked[revokedKeyID] + now := time.Now().UTC() + assert.True(t, revTime.Before(now) || revTime.Equal(now)) + assert.True(t, now.Sub(revTime) < time.Minute) +} diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updatemanager/reposign/root.go new file mode 100644 index 000000000..2c3ca54a0 --- /dev/null +++ b/client/internal/updatemanager/reposign/root.go @@ -0,0 +1,120 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "fmt" + "time" +) + +const ( + tagRootPrivate = "ROOT PRIVATE KEY" + tagRootPublic = "ROOT PUBLIC KEY" +) + +// RootKey is a root Key used to sign signing keys +type RootKey struct { + PrivateKey +} + +func (k RootKey) String() string { + return fmt.Sprintf( + "RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]", + k.Metadata.ID, + k.Metadata.CreatedAt.Format(time.RFC3339), + k.Metadata.ExpiresAt.Format(time.RFC3339), + ) +} + +func ParseRootKey(privKeyPEM []byte) (*RootKey, error) { + pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root Key: %w", err) + } + return &RootKey{pk}, nil +} + +// ParseRootPublicKey parses a root public key from PEM format +func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) { + pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err) + } + return pk, nil +} + +// GenerateRootKey generates a new root Key pair with Metadata +func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) { + now := time.Now() + expirationTime := now.Add(expiration) + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, nil, err + } + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: now.UTC(), + ExpiresAt: expirationTime.UTC(), + } + + rk := &RootKey{ + PrivateKey{ + Key: priv, + Metadata: metadata, + }, + } + + // Marshal PrivateKey struct to JSON + privJSON, err := json.Marshal(rk.PrivateKey) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + // Marshal PublicKey struct to JSON + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + pubJSON, err := json.Marshal(pubKey) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM with metadata embedded in bytes + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON, + }) + + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: pubJSON, + }) + + return rk, privPEM, pubPEM, nil +} + +func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) { + timestamp := time.Now().UTC() + + // This ensures the timestamp is cryptographically bound to the signature + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(rootKey.Key, msg) + // Create signature bundle with timestamp and Metadata + bundle := Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: rootKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + return json.Marshal(bundle) +} diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updatemanager/reposign/root_test.go new file mode 100644 index 000000000..e75e29729 --- /dev/null +++ b/client/internal/updatemanager/reposign/root_test.go @@ -0,0 +1,476 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test RootKey.String() + +func TestRootKey_String(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC) + + rk := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: expiresAt, + }, + }, + } + + str := rk.String() + assert.Contains(t, str, "RootKey") + assert.Contains(t, str, computeKeyID(pub).String()) + assert.Contains(t, str, "2024-01-15") + assert.Contains(t, str, "2034-01-15") +} + +func TestRootKey_String_NoExpiration(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + + rk := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: time.Time{}, // No expiration + }, + }, + } + + str := rk.String() + assert.Contains(t, str, "RootKey") + assert.Contains(t, str, "0001-01-01") // Zero time format +} + +// Test GenerateRootKey + +func TestGenerateRootKey_Valid(t *testing.T) { + expiration := 10 * 365 * 24 * time.Hour // 10 years + + rk, privPEM, pubPEM, err := GenerateRootKey(expiration) + require.NoError(t, err) + assert.NotNil(t, rk) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + + // Verify the key has correct metadata + assert.False(t, rk.Metadata.CreatedAt.IsZero()) + assert.False(t, rk.Metadata.ExpiresAt.IsZero()) + assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt)) + + // Verify expiration is approximately correct + expectedExpiration := time.Now().Add(expiration) + timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration) + assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute) +} + +func TestGenerateRootKey_ShortExpiration(t *testing.T) { + expiration := 24 * time.Hour // 1 day + + rk, _, _, err := GenerateRootKey(expiration) + require.NoError(t, err) + assert.NotNil(t, rk) + + // Verify expiration + expectedExpiration := time.Now().Add(expiration) + timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration) + assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute) +} + +func TestGenerateRootKey_ZeroExpiration(t *testing.T) { + rk, _, _, err := GenerateRootKey(0) + require.NoError(t, err) + assert.NotNil(t, rk) + + // With zero expiration, ExpiresAt should be equal to CreatedAt + assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt) +} + +func TestGenerateRootKey_PEMFormat(t *testing.T) { + rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Verify private key PEM + privBlock, _ := pem.Decode(privPEM) + require.NotNil(t, privBlock) + assert.Equal(t, tagRootPrivate, privBlock.Type) + + var privKey PrivateKey + err = json.Unmarshal(privBlock.Bytes, &privKey) + require.NoError(t, err) + assert.Equal(t, rk.Key, privKey.Key) + + // Verify public key PEM + pubBlock, _ := pem.Decode(pubPEM) + require.NotNil(t, pubBlock) + assert.Equal(t, tagRootPublic, pubBlock.Type) + + var pubKey PublicKey + err = json.Unmarshal(pubBlock.Bytes, &pubKey) + require.NoError(t, err) + assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID) +} + +func TestGenerateRootKey_KeySize(t *testing.T) { + rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Ed25519 private key should be 64 bytes + assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key)) + + // Ed25519 public key should be 32 bytes + pubKey := rk.Key.Public().(ed25519.PublicKey) + assert.Equal(t, ed25519.PublicKeySize, len(pubKey)) +} + +func TestGenerateRootKey_UniqueKeys(t *testing.T) { + rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Different keys should have different IDs + assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID) + assert.NotEqual(t, rk1.Key, rk2.Key) +} + +// Test ParseRootKey + +func TestParseRootKey_Valid(t *testing.T) { + original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + parsed, err := ParseRootKey(privPEM) + require.NoError(t, err) + assert.NotNil(t, parsed) + + // Verify the parsed key matches the original + assert.Equal(t, original.Key, parsed.Key) + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) + assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix()) + assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix()) +} + +func TestParseRootKey_InvalidPEM(t *testing.T) { + _, err := ParseRootKey([]byte("not a valid PEM")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestParseRootKey_EmptyData(t *testing.T) { + _, err := ParseRootKey([]byte{}) + assert.Error(t, err) +} + +func TestParseRootKey_WrongType(t *testing.T) { + // Generate an artifact key instead of root key + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + artifactKey, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Try to parse artifact key as root key + _, err = ParseRootKey(privPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") + + // Just to use artifactKey to avoid unused variable warning + _ = artifactKey +} + +func TestParseRootKey_CorruptedJSON(t *testing.T) { + // Create PEM with corrupted JSON + corruptedPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: []byte("corrupted json data"), + }) + + _, err := ParseRootKey(corruptedPEM) + assert.Error(t, err) +} + +func TestParseRootKey_InvalidKeySize(t *testing.T) { + // Create a key with invalid size + invalidKey := PrivateKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + privJSON, err := json.Marshal(invalidKey) + require.NoError(t, err) + + invalidPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON, + }) + + _, err = ParseRootKey(invalidPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 private key size") +} + +func TestParseRootKey_Roundtrip(t *testing.T) { + // Generate a key + original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Parse it + parsed, err := ParseRootKey(privPEM) + require.NoError(t, err) + + // Generate PEM again from parsed key + privJSON2, err := json.Marshal(parsed.PrivateKey) + require.NoError(t, err) + + privPEM2 := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON2, + }) + + // Parse again + parsed2, err := ParseRootKey(privPEM2) + require.NoError(t, err) + + // Should still match original + assert.Equal(t, original.Key, parsed2.Key) + assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID) +} + +// Test SignArtifactKey + +func TestSignArtifactKey_Valid(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data := []byte("test data to sign") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Parse and verify signature + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, rootKey.Metadata.ID, sig.KeyID) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "sha512", sig.HashAlgo) + assert.False(t, sig.Timestamp.IsZero()) +} + +func TestSignArtifactKey_EmptyData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + sigData, err := SignArtifactKey(*rootKey, []byte{}) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Should still be able to parse + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestSignArtifactKey_Verify(t *testing.T) { + rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Parse public key + pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic) + require.NoError(t, err) + + // Sign some data + data := []byte("test data for verification") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + // Parse signature + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Reconstruct message + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix())) + + // Verify signature + valid := ed25519.Verify(pubKey.Key, msg, sig.Signature) + assert.True(t, valid) +} + +func TestSignArtifactKey_DifferentData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data1 := []byte("data1") + data2 := []byte("data2") + + sig1, err := SignArtifactKey(*rootKey, data1) + require.NoError(t, err) + + sig2, err := SignArtifactKey(*rootKey, data2) + require.NoError(t, err) + + // Different data should produce different signatures + assert.NotEqual(t, sig1, sig2) +} + +func TestSignArtifactKey_MultipleSignatures(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data := []byte("test data") + + // Sign twice with a small delay + sig1, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + sig2, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + // Signatures should be different due to different timestamps + assert.NotEqual(t, sig1, sig2) + + // Parse both signatures + parsed1, err := ParseSignature(sig1) + require.NoError(t, err) + + parsed2, err := ParseSignature(sig2) + require.NoError(t, err) + + // Timestamps should be different + assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp)) +} + +func TestSignArtifactKey_LargeData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Create 1MB of data + largeData := make([]byte, 1024*1024) + for i := range largeData { + largeData[i] = byte(i % 256) + } + + sigData, err := SignArtifactKey(*rootKey, largeData) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestSignArtifactKey_TimestampInSignature(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + beforeSign := time.Now().UTC() + data := []byte("test data") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + afterSign := time.Now().UTC() + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Timestamp should be between before and after + assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second))) + assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second))) +} + +// Integration test + +func TestRootKey_FullWorkflow(t *testing.T) { + // Step 1: Generate root key + rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + assert.NotNil(t, rootKey) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + + // Step 2: Parse the private key back + parsedRootKey, err := ParseRootKey(privPEM) + require.NoError(t, err) + assert.Equal(t, rootKey.Key, parsedRootKey.Key) + assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID) + + // Step 3: Generate an artifact key using root key + artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, artifactKey) + + // Step 4: Verify the artifact key signature + pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic) + require.NoError(t, err) + + sig, err := ParseSignature(artifactSig) + require.NoError(t, err) + + artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic) + require.NoError(t, err) + + // Reconstruct message - SignArtifactKey signs the PEM, not the JSON + msg := make([]byte, 0, len(artifactPubPEM)+8) + msg = append(msg, artifactPubPEM...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix())) + + // Verify with root public key + valid := ed25519.Verify(pubKey.Key, msg, sig.Signature) + assert.True(t, valid, "Artifact key signature should be valid") + + // Step 5: Use artifact key to sign data + testData := []byte("This is test artifact data") + dataSig, err := SignData(*artifactKey, testData) + require.NoError(t, err) + assert.NotEmpty(t, dataSig) + + // Step 6: Verify the artifact data signature + dataSigParsed, err := ParseSignature(dataSig) + require.NoError(t, err) + + err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed) + assert.NoError(t, err, "Artifact data signature should be valid") +} + +func TestRootKey_ExpiredKeyWorkflow(t *testing.T) { + // Generate a root key that expires very soon + rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond) + require.NoError(t, err) + + // Wait for expiration + time.Sleep(10 * time.Millisecond) + + // Try to generate artifact key with expired root key + _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} diff --git a/client/internal/updatemanager/reposign/signature.go b/client/internal/updatemanager/reposign/signature.go new file mode 100644 index 000000000..c7f06e94e --- /dev/null +++ b/client/internal/updatemanager/reposign/signature.go @@ -0,0 +1,24 @@ +package reposign + +import ( + "encoding/json" + "time" +) + +// Signature contains a signature with associated Metadata +type Signature struct { + Signature []byte `json:"signature"` + Timestamp time.Time `json:"timestamp"` + KeyID KeyID `json:"key_id"` + Algorithm string `json:"algorithm"` // "ed25519" + HashAlgo string `json:"hash_algo"` // "blake2s" or sha512 +} + +func ParseSignature(data []byte) (*Signature, error) { + var signature Signature + if err := json.Unmarshal(data, &signature); err != nil { + return nil, err + } + + return &signature, nil +} diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updatemanager/reposign/signature_test.go new file mode 100644 index 000000000..1960c5518 --- /dev/null +++ b/client/internal/updatemanager/reposign/signature_test.go @@ -0,0 +1,277 @@ +package reposign + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseSignature_Valid(t *testing.T) { + timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + keyID, err := ParseKeyID("0123456789abcdef") + require.NoError(t, err) + + signatureData := []byte{0x01, 0x02, 0x03, 0x04} + + jsonData, err := json.Marshal(Signature{ + Signature: signatureData, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + }) + require.NoError(t, err) + + sig, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.Equal(t, signatureData, sig.Signature) + assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix()) + assert.Equal(t, keyID, sig.KeyID) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} + +func TestParseSignature_InvalidJSON(t *testing.T) { + invalidJSON := []byte(`{invalid json}`) + + sig, err := ParseSignature(invalidJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestParseSignature_EmptyData(t *testing.T) { + emptyJSON := []byte(`{}`) + + sig, err := ParseSignature(emptyJSON) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.Empty(t, sig.Signature) + assert.True(t, sig.Timestamp.IsZero()) + assert.Equal(t, KeyID{}, sig.KeyID) + assert.Empty(t, sig.Algorithm) + assert.Empty(t, sig.HashAlgo) +} + +func TestParseSignature_MissingFields(t *testing.T) { + // JSON with only some fields + partialJSON := []byte(`{ + "signature": "AQIDBA==", + "algorithm": "ed25519" + }`) + + sig, err := ParseSignature(partialJSON) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.True(t, sig.Timestamp.IsZero()) +} + +func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) { + timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC) + keyID, err := ParseKeyID("fedcba9876543210") + require.NoError(t, err) + + original := Signature{ + Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + // Marshal + jsonData, err := json.Marshal(original) + require.NoError(t, err) + + // Unmarshal + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + + // Verify + assert.Equal(t, original.Signature, parsed.Signature) + assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix()) + assert.Equal(t, original.KeyID, parsed.KeyID) + assert.Equal(t, original.Algorithm, parsed.Algorithm) + assert.Equal(t, original.HashAlgo, parsed.HashAlgo) +} + +func TestSignature_NilSignatureBytes(t *testing.T) { + timestamp := time.Now().UTC() + keyID, err := ParseKeyID("0011223344556677") + require.NoError(t, err) + + sig := Signature{ + Signature: nil, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Nil(t, parsed.Signature) +} + +func TestSignature_LargeSignature(t *testing.T) { + timestamp := time.Now().UTC() + keyID, err := ParseKeyID("aabbccddeeff0011") + require.NoError(t, err) + + // Create a large signature (64 bytes for ed25519) + largeSignature := make([]byte, 64) + for i := range largeSignature { + largeSignature[i] = byte(i) + } + + sig := Signature{ + Signature: largeSignature, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, largeSignature, parsed.Signature) +} + +func TestSignature_WithDifferentHashAlgorithms(t *testing.T) { + tests := []struct { + name string + hashAlgo string + }{ + {"blake2s", "blake2s"}, + {"sha512", "sha512"}, + {"sha256", "sha256"}, + {"empty", ""}, + } + + keyID, err := ParseKeyID("1122334455667788") + require.NoError(t, err) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sig := Signature{ + Signature: []byte{0x01, 0x02}, + Timestamp: time.Now().UTC(), + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: tt.hashAlgo, + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, tt.hashAlgo, parsed.HashAlgo) + }) + } +} + +func TestSignature_TimestampPrecision(t *testing.T) { + // Test that timestamp preserves precision through JSON marshaling + timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC) + keyID, err := ParseKeyID("8877665544332211") + require.NoError(t, err) + + sig := Signature{ + Signature: []byte{0xaa, 0xbb}, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + + // JSON timestamps typically have second or millisecond precision + // so we check that at least seconds match + assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix()) +} + +func TestParseSignature_MalformedKeyID(t *testing.T) { + // Test with a malformed KeyID field + malformedJSON := []byte(`{ + "signature": "AQID", + "timestamp": "2024-01-15T10:30:00Z", + "key_id": "invalid_keyid_format", + "algorithm": "ed25519", + "hash_algo": "blake2s" + }`) + + // This should fail since "invalid_keyid_format" is not a valid KeyID + sig, err := ParseSignature(malformedJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestParseSignature_InvalidTimestamp(t *testing.T) { + // Test with an invalid timestamp format + invalidTimestampJSON := []byte(`{ + "signature": "AQID", + "timestamp": "not-a-timestamp", + "key_id": "0123456789abcdef", + "algorithm": "ed25519", + "hash_algo": "blake2s" + }`) + + sig, err := ParseSignature(invalidTimestampJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestSignature_ZeroKeyID(t *testing.T) { + // Test with a zero KeyID + sig := Signature{ + Signature: []byte{0x01, 0x02, 0x03}, + Timestamp: time.Now().UTC(), + KeyID: KeyID{}, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, KeyID{}, parsed.KeyID) +} + +func TestParseSignature_ExtraFields(t *testing.T) { + // JSON with extra fields that should be ignored + jsonWithExtra := []byte(`{ + "signature": "AQIDBA==", + "timestamp": "2024-01-15T10:30:00Z", + "key_id": "0123456789abcdef", + "algorithm": "ed25519", + "hash_algo": "blake2s", + "extra_field": "should be ignored", + "another_extra": 12345 + }`) + + sig, err := ParseSignature(jsonWithExtra) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updatemanager/reposign/verify.go new file mode 100644 index 000000000..0af2a8c9e --- /dev/null +++ b/client/internal/updatemanager/reposign/verify.go @@ -0,0 +1,187 @@ +package reposign + +import ( + "context" + "fmt" + "net/url" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" +) + +const ( + artifactPubKeysFileName = "artifact-key-pub.pem" + artifactPubKeysSigFileName = "artifact-key-pub.pem.sig" + revocationFileName = "revocation-list.json" + revocationSignFileName = "revocation-list.json.sig" + + keySizeLimit = 5 * 1024 * 1024 //5MB + signatureLimit = 1024 + revocationLimit = 10 * 1024 * 1024 +) + +type ArtifactVerify struct { + rootKeys []PublicKey + keysBaseURL *url.URL + + revocationList *RevocationList +} + +func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) { + allKeys, err := loadEmbeddedPublicKeys() + if err != nil { + return nil, err + } + + return newArtifactVerify(keysBaseURL, allKeys) +} + +func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) { + ku, err := url.Parse(keysBaseURL) + if err != nil { + return nil, fmt.Errorf("invalid keys base URL %q: %v", keysBaseURL, err) + } + + a := &ArtifactVerify{ + rootKeys: allKeys, + keysBaseURL: ku, + } + return a, nil +} + +func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error { + version = strings.TrimPrefix(version, "v") + + revocationList, err := a.loadRevocationList(ctx) + if err != nil { + return fmt.Errorf("failed to load revocation list: %v", err) + } + a.revocationList = revocationList + + artifactPubKeys, err := a.loadArtifactKeys(ctx) + if err != nil { + return fmt.Errorf("failed to load artifact keys: %v", err) + } + + signature, err := a.loadArtifactSignature(ctx, version, artifactFile) + if err != nil { + return fmt.Errorf("failed to download signature file for: %s, %v", filepath.Base(artifactFile), err) + } + + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + log.Errorf("failed to read artifact file: %v", err) + return fmt.Errorf("failed to read artifact file: %w", err) + } + + if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil { + return fmt.Errorf("failed to validate artifact: %v", err) + } + + return nil +} + +func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) { + downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String() + data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit) + if err != nil { + log.Debugf("failed to download revocation list '%s': %s", downloadURL, err) + return nil, err + } + + downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String() + sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download revocation list '%s': %s", downloadURL, err) + return nil, err + } + + signature, err := ParseSignature(sigData) + if err != nil { + log.Debugf("failed to parse revocation list signature: %s", err) + return nil, err + } + + return ValidateRevocationList(a.rootKeys, data, *signature) +} + +func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) { + downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String() + log.Debugf("starting downloading artifact keys from: %s", downloadURL) + data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit) + if err != nil { + log.Debugf("failed to download artifact keys: %s", err) + return nil, err + } + + downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String() + log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL) + sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download signature of public keys: %s", err) + return nil, err + } + + signature, err := ParseSignature(sigData) + if err != nil { + log.Debugf("failed to parse signature of public keys: %s", err) + return nil, err + } + + return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList) +} + +func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) { + artifactFile = filepath.Base(artifactFile) + downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String() + data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download artifact signature: %s", err) + return nil, err + } + + signature, err := ParseSignature(data) + if err != nil { + log.Debugf("failed to parse artifact signature: %s", err) + return nil, err + } + + return signature, nil + +} + +func loadEmbeddedPublicKeys() ([]PublicKey, error) { + files, err := embeddedCerts.ReadDir(embeddedCertsDir) + if err != nil { + return nil, fmt.Errorf("failed to read embedded certs: %w", err) + } + + var allKeys []PublicKey + for _, file := range files { + if file.IsDir() { + continue + } + + data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name()) + if err != nil { + return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err) + } + + keys, err := parsePublicKeyBundle(data, tagRootPublic) + if err != nil { + return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err) + } + + allKeys = append(allKeys, keys...) + } + + if len(allKeys) == 0 { + return nil, fmt.Errorf("no valid public keys found in embedded certs") + } + + return allKeys, nil +} diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updatemanager/reposign/verify_test.go new file mode 100644 index 000000000..c29393bad --- /dev/null +++ b/client/internal/updatemanager/reposign/verify_test.go @@ -0,0 +1,528 @@ +package reposign + +import ( + "context" + "crypto/ed25519" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test ArtifactVerify construction + +func TestArtifactVerify_Construction(t *testing.T) { + // Generate test root key + rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic) + require.NoError(t, err) + + keysBaseURL := "http://localhost:8080/artifact-signatures" + + av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey}) + require.NoError(t, err) + + assert.NotNil(t, av) + assert.NotEmpty(t, av.rootKeys) + assert.Equal(t, keysBaseURL, av.keysBaseURL.String()) + + // Verify root key structure + assert.NotEmpty(t, av.rootKeys[0].Key) + assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID) + assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero()) +} + +func TestArtifactVerify_MultipleRootKeys(t *testing.T) { + // Generate multiple test root keys + rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic) + require.NoError(t, err) + + rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic) + require.NoError(t, err) + + keysBaseURL := "http://localhost:8080/artifact-signatures" + + av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2}) + assert.NoError(t, err) + assert.Len(t, av.rootKeys, 2) + assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID) +} + +// Test Verify workflow with mock HTTP server + +func TestArtifactVerify_FullWorkflow(t *testing.T) { + // Create temporary test directory + tempDir := t.TempDir() + + // Step 1: Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Step 2: Generate artifact key + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + // Step 3: Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 4: Bundle artifact keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Step 5: Create test artifact + artifactPath := filepath.Join(tempDir, "test-artifact.bin") + artifactData := []byte("This is test artifact data for verification") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + // Step 6: Sign artifact + artifactSigData, err := SignData(*artifactKey, artifactData) + require.NoError(t, err) + + // Step 7: Setup mock HTTP server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifacts/v1.0.0/test-artifact.bin": + _, _ = w.Write(artifactData) + case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + // Step 8: Create ArtifactVerify with test root key + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + // Step 9: Verify artifact + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.NoError(t, err) +} + +func TestArtifactVerify_InvalidRevocationList(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + // Setup server with invalid revocation list + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write([]byte("invalid data")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load revocation list") +} + +func TestArtifactVerify_MissingArtifactFile(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + // Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Create signature for non-existent file + testData := []byte("test") + artifactSigData, err := SignData(*artifactKey, testData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/missing.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", "file.bin") + assert.Error(t, err) +} + +func TestArtifactVerify_ServerUnavailable(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + // Use URL that doesn't exist + av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) +} + +func TestArtifactVerify_ContextCancellation(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + // Create a server that delays response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(500 * time.Millisecond) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey}) + require.NoError(t, err) + + // Create context that cancels quickly + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) +} + +func TestArtifactVerify_WithRevocation(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Generate two artifact keys + artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1) + require.NoError(t, err) + + _, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2) + require.NoError(t, err) + + // Create revocation list with first key revoked + emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + parsedRevocation, err := ParseRevocationList(emptyRevocation) + require.NoError(t, err) + + revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle both keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2}) + require.NoError(t, err) + + // Create artifact signed by revoked key + artifactPath := filepath.Join(tempDir, "test.bin") + artifactData := []byte("test data") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + artifactSigData, err := SignData(*artifactKey1, artifactData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should fail because the signing key is revoked + assert.Error(t, err) + assert.Contains(t, err.Error(), "no signing Key found") +} + +func TestArtifactVerify_ValidWithSecondKey(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Generate two artifact keys + _, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1) + require.NoError(t, err) + + artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2) + require.NoError(t, err) + + // Create revocation list with first key revoked + emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + parsedRevocation, err := ParseRevocationList(emptyRevocation) + require.NoError(t, err) + + revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle both keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2}) + require.NoError(t, err) + + // Create artifact signed by second key (not revoked) + artifactPath := filepath.Join(tempDir, "test.bin") + artifactData := []byte("test data") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + artifactSigData, err := SignData(*artifactKey2, artifactData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should succeed because second key is not revoked + assert.NoError(t, err) +} + +func TestArtifactVerify_TamperedArtifact(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key and artifact key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + // Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Sign original data + originalData := []byte("original data") + artifactSigData, err := SignData(*artifactKey, originalData) + require.NoError(t, err) + + // Write tampered data to file + artifactPath := filepath.Join(tempDir, "test.bin") + tamperedData := []byte("tampered data") + err = os.WriteFile(artifactPath, tamperedData, 0644) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should fail because artifact was tampered + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to validate artifact") +} + +// Test URL validation + +func TestArtifactVerify_URLParsing(t *testing.T) { + tests := []struct { + name string + keysBaseURL string + expectError bool + }{ + { + name: "Valid HTTP URL", + keysBaseURL: "http://example.com/artifact-signatures", + expectError: false, + }, + { + name: "Valid HTTPS URL", + keysBaseURL: "https://example.com/artifact-signatures", + expectError: false, + }, + { + name: "URL with port", + keysBaseURL: "http://localhost:8080/artifact-signatures", + expectError: false, + }, + { + name: "Invalid URL", + keysBaseURL: "://invalid", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := newArtifactVerify(tt.keysBaseURL, nil) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/client/internal/updatemanager/update.go b/client/internal/updatemanager/update.go new file mode 100644 index 000000000..875b50b49 --- /dev/null +++ b/client/internal/updatemanager/update.go @@ -0,0 +1,11 @@ +package updatemanager + +import v "github.com/hashicorp/go-version" + +type UpdateInterface interface { + StopWatch() + SetDaemonVersion(newVersion string) bool + SetOnUpdateListener(updateFn func()) + LatestVersion() *v.Version + StartFetcher() +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 463c93d57..f3458ccea 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -131,7 +131,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } diff --git a/client/netbird.wxs b/client/netbird.wxs index ba827debf..03221dd91 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -51,7 +51,7 @@ - + diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 28e8b2d4e..5ae0c1ad1 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.32.1 +// protoc v3.21.12 // source: daemon.proto package proto @@ -893,6 +893,7 @@ type UpRequest struct { state protoimpl.MessageState `protogen:"open.v1"` ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` + AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -941,6 +942,13 @@ func (x *UpRequest) GetUsername() string { return "" } +func (x *UpRequest) GetAutoUpdate() bool { + if x != nil && x.AutoUpdate != nil { + return *x.AutoUpdate + } + return false +} + type UpResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -5356,6 +5364,94 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 { return 0 } +type InstallerResultRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstallerResultRequest) Reset() { + *x = InstallerResultRequest{} + mi := &file_daemon_proto_msgTypes[79] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstallerResultRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstallerResultRequest) ProtoMessage() {} + +func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[79] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. +func (*InstallerResultRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{79} +} + +type InstallerResultResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstallerResultResponse) Reset() { + *x = InstallerResultResponse{} + mi := &file_daemon_proto_msgTypes[80] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstallerResultResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstallerResultResponse) ProtoMessage() {} + +func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[80] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. +func (*InstallerResultResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{80} +} + +func (x *InstallerResultResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *InstallerResultResponse) GetErrorMsg() string { + if x != nil { + return x.ErrorMsg + } + return "" +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5366,7 +5462,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5378,7 +5474,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5502,12 +5598,16 @@ const file_daemon_proto_rawDesc = "" + "\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" + "\bhostname\x18\x02 \x01(\tR\bhostname\",\n" + "\x14WaitSSOLoginResponse\x12\x14\n" + - "\x05email\x18\x01 \x01(\tR\x05email\"p\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" + "\tUpRequest\x12%\n" + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" + + "\n" + + "autoUpdate\x18\x03 \x01(\bH\x02R\n" + + "autoUpdate\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + - "\t_username\"\f\n" + + "\t_usernameB\r\n" + + "\v_autoUpdate\"\f\n" + "\n" + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + @@ -5893,7 +5993,11 @@ const file_daemon_proto_rawDesc = "" + "\x14WaitJWTTokenResponse\x12\x14\n" + "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" + "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" + - "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" + + "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" + + "\x16InstallerResultRequest\"O\n" + + "\x17InstallerResultResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -5902,7 +6006,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\xdb\x12\n" + + "\x05TRACE\x10\a2\xb4\x13\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -5938,7 +6042,8 @@ const file_daemon_proto_rawDesc = "" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" + - "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3" + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -5953,7 +6058,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType @@ -6038,19 +6143,21 @@ var file_daemon_proto_goTypes = []any{ (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse - nil, // 83: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range - nil, // 85: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 86: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp + (*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse + nil, // 85: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range + nil, // 87: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 88: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState @@ -6061,8 +6168,8 @@ var file_daemon_proto_depIdxs = []int32{ 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -6073,10 +6180,10 @@ var file_daemon_proto_depIdxs = []int32{ 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest @@ -6111,40 +6218,42 @@ var file_daemon_proto_depIdxs = []int32{ 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 66, // [66:98] is the sub-list for method output_type - 34, // [34:66] is the sub-list for method input_type + 83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 67, // [67:100] is the sub-list for method output_type + 34, // [34:67] is the sub-list for method input_type 34, // [34:34] is the sub-list for extension type_name 34, // [34:34] is the sub-list for extension extendee 0, // [0:34] is the sub-list for field type_name @@ -6174,7 +6283,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 4, - NumMessages: 82, + NumMessages: 84, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 3dfd3da8d..5f30bfe4b 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -95,6 +95,8 @@ service DaemonService { rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} + + rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} } @@ -215,6 +217,7 @@ message WaitSSOLoginResponse { message UpRequest { optional string profileName = 1; optional string username = 2; + optional bool autoUpdate = 3; } message UpResponse {} @@ -772,3 +775,11 @@ message WaitJWTTokenResponse { // expiration time in seconds int64 expiresIn = 3; } + +message InstallerResultRequest { +} + +message InstallerResultResponse { + bool success = 1; + string errorMsg = 2; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 6b01309b7..fdabb1879 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -71,6 +71,7 @@ type DaemonServiceClient interface { // WaitJWTToken waits for JWT authentication completion WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) + GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) } type daemonServiceClient struct { @@ -392,6 +393,15 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec return out, nil } +func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) { + out := new(InstallerResultResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -449,6 +459,7 @@ type DaemonServiceServer interface { // WaitJWTToken waits for JWT authentication completion WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) + GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -552,6 +563,9 @@ func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTo func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") } +func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1144,6 +1158,24 @@ func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Conte return interceptor(ctx, in, info, handler) } +func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InstallerResultRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).GetInstallerResult(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/GetInstallerResult", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1275,6 +1307,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "NotifyOSLifecycle", Handler: _DaemonService_NotifyOSLifecycle_Handler, }, + { + MethodName: "GetInstallerResult", + Handler: _DaemonService_GetInstallerResult_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/proto/generate.sh b/client/proto/generate.sh index f9a2c3750..e659cef90 100755 --- a/client/proto/generate.sh +++ b/client/proto/generate.sh @@ -14,4 +14,4 @@ cd "$script_path" go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1 protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional -cd "$old_pwd" \ No newline at end of file +cd "$old_pwd" diff --git a/client/server/server.go b/client/server/server.go index d33595115..99da4e36f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -192,7 +192,7 @@ func (s *Server) Start() error { s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan) return nil } @@ -223,7 +223,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -231,7 +231,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() if s.config.DisableAutoConnect { - if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil { log.Debugf("run client connection exited with error: %v", err) } log.Tracef("client connection exited") @@ -260,7 +260,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() runOperation := func() error { - err := s.connect(ctx, profileConfig, statusRecorder, runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan) + doInitialAutoUpdate = false if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) return err @@ -728,7 +729,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + var doAutoUpdate bool + if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate { + doAutoUpdate = true + } + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) return s.waitForUp(callerCtx) } @@ -1539,9 +1545,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) return features, nil } -func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) if err := s.connectClient.Run(runningChan); err != nil { return err diff --git a/client/server/server_test.go b/client/server/server_test.go index 5f28a2664..69b4453ea 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -112,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } diff --git a/client/server/updateresult.go b/client/server/updateresult.go new file mode 100644 index 000000000..8e00d5062 --- /dev/null +++ b/client/server/updateresult.go @@ -0,0 +1,30 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/proto" +) + +func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) { + inst := installer.New() + dir := inst.TempDir() + + rh := installer.NewResultHandler(dir) + result, err := rh.Watch(ctx) + if err != nil { + log.Errorf("failed to watch update result: %v", err) + return &proto.InstallerResultResponse{ + Success: false, + ErrorMsg: err.Error(), + }, nil + } + + return &proto.InstallerResultResponse{ + Success: result.Success, + ErrorMsg: result.Error, + }, nil +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 8f99608e7..c786d3a61 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -34,6 +34,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + protobuf "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" @@ -43,7 +44,6 @@ import ( "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" "github.com/netbirdio/netbird/client/ui/process" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" @@ -87,22 +87,24 @@ func main() { // Create the service client (this also builds the settings or networks UI if requested). client := newServiceClient(&newServiceClientArgs{ - addr: flags.daemonAddr, - logFile: logFile, - app: a, - showSettings: flags.showSettings, - showNetworks: flags.showNetworks, - showLoginURL: flags.showLoginURL, - showDebug: flags.showDebug, - showProfiles: flags.showProfiles, - showQuickActions: flags.showQuickActions, + addr: flags.daemonAddr, + logFile: logFile, + app: a, + showSettings: flags.showSettings, + showNetworks: flags.showNetworks, + showLoginURL: flags.showLoginURL, + showDebug: flags.showDebug, + showProfiles: flags.showProfiles, + showQuickActions: flags.showQuickActions, + showUpdate: flags.showUpdate, + showUpdateVersion: flags.showUpdateVersion, }) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) // Run in window mode if any UI flag was set. - if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions { + if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate { a.Run() return } @@ -128,15 +130,17 @@ func main() { } type cliFlags struct { - daemonAddr string - showSettings bool - showNetworks bool - showProfiles bool - showDebug bool - showLoginURL bool - showQuickActions bool - errorMsg string - saveLogsInFile bool + daemonAddr string + showSettings bool + showNetworks bool + showProfiles bool + showDebug bool + showLoginURL bool + showQuickActions bool + errorMsg string + saveLogsInFile bool + showUpdate bool + showUpdateVersion string } // parseFlags reads and returns all needed command-line flags. @@ -156,6 +160,8 @@ func parseFlags() *cliFlags { flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window") flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window") + flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window") + flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to") flag.Parse() return &flags } @@ -319,6 +325,8 @@ type serviceClient struct { mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window + wUpdateProgress fyne.Window + updateContextCancel context.CancelFunc connectCancel context.CancelFunc } @@ -329,15 +337,17 @@ type menuHandler struct { } type newServiceClientArgs struct { - addr string - logFile string - app fyne.App - showSettings bool - showNetworks bool - showDebug bool - showLoginURL bool - showProfiles bool - showQuickActions bool + addr string + logFile string + app fyne.App + showSettings bool + showNetworks bool + showDebug bool + showLoginURL bool + showProfiles bool + showQuickActions bool + showUpdate bool + showUpdateVersion string } // newServiceClient instance constructor @@ -355,7 +365,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, - update: version.NewUpdate("nb/client-ui"), + update: version.NewUpdateAndStart("nb/client-ui"), } s.eventHandler = newEventHandler(s) @@ -375,6 +385,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { s.showProfilesUI() case args.showQuickActions: s.showQuickActionsUI() + case args.showUpdate: + s.showUpdateProgress(ctx, args.showUpdateVersion) } return s @@ -814,7 +826,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log return nil } -func (s *serviceClient) menuUpClick(ctx context.Context) error { +func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { @@ -836,7 +848,9 @@ func (s *serviceClient) menuUpClick(ctx context.Context) error { return nil } - if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil { + if _, err := s.conn.Up(s.ctx, &proto.UpRequest{ + AutoUpdate: protobuf.Bool(wannaAutoUpdate), + }); err != nil { return fmt.Errorf("start connection: %w", err) } @@ -1097,6 +1111,26 @@ func (s *serviceClient) onTrayReady() { s.updateExitNodes() } }) + s.eventManager.AddHandler(func(event *proto.SystemEvent) { + // todo use new Category + if windowAction, ok := event.Metadata["progress_window"]; ok { + targetVersion, ok := event.Metadata["version"] + if !ok { + targetVersion = "unknown" + } + log.Debugf("window action: %v", windowAction) + if windowAction == "show" { + if s.updateContextCancel != nil { + s.updateContextCancel() + s.updateContextCancel = nil + } + + subCtx, cancel := context.WithCancel(s.ctx) + go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion) + s.updateContextCancel = cancel + } + } + }) go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index e0b619411..9ffacd926 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -80,7 +80,7 @@ func (h *eventHandler) handleConnectClick() { go func() { defer connectCancel() - if err := h.client.menuUpClick(connectCtx); err != nil { + if err := h.client.menuUpClick(connectCtx, true); err != nil { st, ok := status.FromError(err) if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { log.Debugf("connect operation cancelled by user") @@ -185,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() { go func() { defer h.client.mAdvancedSettings.Enable() defer h.client.getSrvConfig() - h.runSelfCommand(h.client.ctx, "settings", "true") + h.runSelfCommand(h.client.ctx, "settings") }() } @@ -193,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() { h.client.mCreateDebugBundle.Disable() go func() { defer h.client.mCreateDebugBundle.Enable() - h.runSelfCommand(h.client.ctx, "debug", "true") + h.runSelfCommand(h.client.ctx, "debug") }() } @@ -217,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() { h.client.mNetworks.Disable() go func() { defer h.client.mNetworks.Enable() - h.runSelfCommand(h.client.ctx, "networks", "true") + h.runSelfCommand(h.client.ctx, "networks") }() } @@ -237,17 +237,21 @@ func (h *eventHandler) updateConfigWithErr() error { return nil } -func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) { +func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) { proc, err := os.Executable() if err != nil { log.Errorf("error getting executable path: %v", err) return } - cmd := exec.CommandContext(ctx, proc, - fmt.Sprintf("--%s=%s", command, arg), + // Build the full command arguments + cmdArgs := []string{ + fmt.Sprintf("--%s=true", command), fmt.Sprintf("--daemon-addr=%s", h.client.addr), - ) + } + cmdArgs = append(cmdArgs, args...) + + cmd := exec.CommandContext(ctx, proc, cmdArgs...) if out := h.client.attachOutput(cmd); out != nil { defer func() { @@ -257,17 +261,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) }() } - log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr) + log.Printf("running command: %s", cmd.String()) if err := cmd.Run(); err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode()) + log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode()) } return } - log.Printf("command '%s %s' completed successfully", command, arg) + log.Printf("command '%s' completed successfully", cmd.String()) } func (h *eventHandler) logout(ctx context.Context) error { diff --git a/client/ui/profile.go b/client/ui/profile.go index 74189c9a0..a38d8918a 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -397,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func(context.Context) error + upClickCallback func(context.Context, bool) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -411,7 +411,7 @@ type newProfileMenuArgs struct { profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func(context.Context) error + upClickCallback func(context.Context, bool) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -579,7 +579,7 @@ func (p *profileMenu) refresh() { connectCtx, connectCancel := context.WithCancel(p.ctx) p.serviceClient.connectCancel = connectCancel - if err := p.upClickCallback(connectCtx); err != nil { + if err := p.upClickCallback(connectCtx, false); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go index bf47ac434..76440d684 100644 --- a/client/ui/quickactions.go +++ b/client/ui/quickactions.go @@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() { connCmd := connectCommand{ connectClient: func() error { - return s.menuUpClick(s.ctx) + return s.menuUpClick(s.ctx, false) }, } diff --git a/client/ui/update.go b/client/ui/update.go new file mode 100644 index 000000000..25c317bdf --- /dev/null +++ b/client/ui/update.go @@ -0,0 +1,140 @@ +//go:build !(linux && 386) + +package main + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/widget" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/proto" +) + +func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) { + log.Infof("show installer progress window: %s", version) + s.wUpdateProgress = s.app.NewWindow("Automatically updating client") + + statusLabel := widget.NewLabel("Updating...") + infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version)) + content := container.NewVBox(infoLabel, statusLabel) + s.wUpdateProgress.SetContent(content) + s.wUpdateProgress.CenterOnScreen() + s.wUpdateProgress.SetFixedSize(true) + s.wUpdateProgress.SetCloseIntercept(func() { + // this is empty to lock window until result known + }) + s.wUpdateProgress.RequestFocus() + s.wUpdateProgress.Show() + + updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) + + // Initialize dot updater + updateText := dotUpdater() + + // Channel to receive the result from RPC call + resultErrCh := make(chan error, 1) + resultOkCh := make(chan struct{}, 1) + + // Start RPC call in background + go func() { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Infof("backend not reachable, upgrade in progress: %v", err) + close(resultOkCh) + return + } + + resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{}) + if err != nil { + log.Infof("backend stopped responding, upgrade in progress: %v", err) + close(resultOkCh) + return + } + + if !resp.Success { + resultErrCh <- mapInstallError(resp.ErrorMsg) + return + } + + // Success + close(resultOkCh) + }() + + // Update UI with dots and wait for result + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + defer cancel() + + // allow closing update window after 10 sec + timerResetCloseInterceptor := time.NewTimer(10 * time.Second) + defer timerResetCloseInterceptor.Stop() + + for { + select { + case <-updateWindowCtx.Done(): + s.showInstallerResult(statusLabel, updateWindowCtx.Err()) + return + case err := <-resultErrCh: + s.showInstallerResult(statusLabel, err) + return + case <-resultOkCh: + log.Info("backend exited, upgrade in progress, closing all UI") + killParentUIProcess() + s.app.Quit() + return + case <-ticker.C: + statusLabel.SetText(updateText()) + case <-timerResetCloseInterceptor.C: + s.wUpdateProgress.SetCloseIntercept(nil) + } + } + }() +} + +func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) { + s.wUpdateProgress.SetCloseIntercept(nil) + switch { + case errors.Is(err, context.DeadlineExceeded): + log.Warn("update watcher timed out") + statusLabel.SetText("Update timed out. Please try again.") + case errors.Is(err, context.Canceled): + log.Info("update watcher canceled") + statusLabel.SetText("Update canceled.") + case err != nil: + log.Errorf("update failed: %v", err) + statusLabel.SetText("Update failed: " + err.Error()) + default: + s.wUpdateProgress.Close() + } +} + +// dotUpdater returns a closure that cycles through dots for a loading animation. +func dotUpdater() func() string { + dotCount := 0 + return func() string { + dotCount = (dotCount + 1) % 4 + return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount)) + } +} + +func mapInstallError(msg string) error { + msg = strings.ToLower(strings.TrimSpace(msg)) + + switch { + case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"): + return context.DeadlineExceeded + case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"): + return context.Canceled + case msg == "": + return errors.New("unknown update error") + default: + return errors.New(msg) + } +} diff --git a/client/ui/update_notwindows.go b/client/ui/update_notwindows.go new file mode 100644 index 000000000..5766f18f7 --- /dev/null +++ b/client/ui/update_notwindows.go @@ -0,0 +1,7 @@ +//go:build !windows && !(linux && 386) + +package main + +func killParentUIProcess() { + // No-op on non-Windows platforms +} diff --git a/client/ui/update_windows.go b/client/ui/update_windows.go new file mode 100644 index 000000000..1b03936f9 --- /dev/null +++ b/client/ui/update_windows.go @@ -0,0 +1,44 @@ +//go:build windows + +package main + +import ( + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + + nbprocess "github.com/netbirdio/netbird/client/ui/process" +) + +// killParentUIProcess finds and kills the parent systray UI process on Windows. +// This is a workaround in case the MSI installer fails to properly terminate the UI process. +// The installer should handle this via util:CloseApplication with TerminateProcess, but this +// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds. +func killParentUIProcess() { + pid, running, err := nbprocess.IsAnotherProcessRunning() + if err != nil { + log.Warnf("failed to check for parent UI process: %v", err) + return + } + + if !running { + log.Debug("no parent UI process found to kill") + return + } + + log.Infof("killing parent UI process (PID: %d)", pid) + + // Open the process with terminate rights + handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid)) + if err != nil { + log.Warnf("failed to open parent process %d: %v", pid, err) + return + } + defer func() { + _ = windows.CloseHandle(handle) + }() + + // Terminate the process with exit code 0 + if err := windows.TerminateProcess(handle, 0); err != nil { + log.Warnf("failed to terminate parent process %d: %v", pid, err) + } +} diff --git a/management/internals/server/server.go b/management/internals/server/server.go index a1b144dac..d9c715225 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) - s.update = version.NewUpdate("nb/management") + s.update = version.NewUpdateAndStart("nb/management") s.update.SetDaemonVersion(version.NetbirdVersion()) s.update.SetOnUpdateListener(func() { log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion()) diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 2b15fe4b8..1a0dc009b 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -7,6 +7,7 @@ import ( "strings" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -101,6 +102,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set Fqdn: fqdn, RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled, + AutoUpdate: &proto.AutoUpdateSettings{ + Version: settings.AutoUpdateVersion, + }, } } @@ -108,9 +112,10 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), NetworkMap: &proto.NetworkMap{ - Serial: networkMap.Network.CurrentSerial(), - Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + Serial: networkMap.Network.CurrentSerial(), + Routes: toProtocolRoutes(networkMap.Routes), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), }, Checks: toProtocolChecks(ctx, checks), } diff --git a/management/server/account.go b/management/server/account.go index a9becc4b6..04f83842e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -321,7 +321,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || - oldSettings.DNSDomain != newSettings.DNSDomain { + oldSettings.DNSDomain != newSettings.DNSDomain || + oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { updateAccountPeers = true } @@ -360,6 +361,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { return nil, err } @@ -451,6 +453,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con } } +func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{ + "version": newSettings.AutoUpdateVersion, + }) + } +} + func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 2e3be1ef5..6344b2904 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -181,6 +181,8 @@ const ( UserRejected Activity = 90 UserCreated Activity = 91 + AccountAutoUpdateVersionUpdated Activity = 92 + AccountDeleted Activity = 99999 ) @@ -287,9 +289,12 @@ var activityMap = map[Activity]Code{ AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"}, PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, - UserApproved: {"User approved", "user.approve"}, - UserRejected: {"User rejected", "user.reject"}, - UserCreated: {"User created", "user.create"}, + + UserApproved: {"User approved", "user.approve"}, + UserRejected: {"User rejected", "user.reject"}, + UserCreated: {"User created", "user.create"}, + + AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index f1552d0ea..3797b0512 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -3,12 +3,15 @@ package accounts import ( "context" "encoding/json" + "fmt" "net/http" "net/netip" "time" "github.com/gorilla/mux" + goversion "github.com/hashicorp/go-version" + "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/settings" @@ -26,7 +29,9 @@ const ( // MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16) MinNetworkBitsIPv4 = 28 // MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges - MinNetworkBitsIPv6 = 120 + MinNetworkBitsIPv6 = 120 + disableAutoUpdate = "disabled" + autoUpdateLatestVersion = "latest" ) // handler is a handler that handles the server.Account HTTP endpoints @@ -162,6 +167,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } +func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) { + returnSettings := &types.Settings{ + PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, + PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), + RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), + } + + if req.Settings.Extra != nil { + returnSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, + FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, + FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, + FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, + } + } + + if req.Settings.JwtGroupsEnabled != nil { + returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled + } + if req.Settings.GroupsPropagationEnabled != nil { + returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled + } + if req.Settings.JwtGroupsClaimName != nil { + returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName + } + if req.Settings.JwtAllowGroups != nil { + returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups + } + if req.Settings.RoutingPeerDnsResolutionEnabled != nil { + returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled + } + if req.Settings.DnsDomain != nil { + returnSettings.DNSDomain = *req.Settings.DnsDomain + } + if req.Settings.LazyConnectionEnabled != nil { + returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled + } + if req.Settings.AutoUpdateVersion != nil { + _, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion) + if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion || + *req.Settings.AutoUpdateVersion == disableAutoUpdate || + err == nil { + returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion + } else if *req.Settings.AutoUpdateVersion != "" { + return nil, fmt.Errorf("invalid AutoUpdateVersion") + } + } + + return returnSettings, nil +} + // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) @@ -186,45 +246,10 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - settings := &types.Settings{ - PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, - PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), - RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, - - PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, - PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), - } - - if req.Settings.Extra != nil { - settings.Extra = &types.ExtraSettings{ - PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, - UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, - FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, - FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, - FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, - } - } - - if req.Settings.JwtGroupsEnabled != nil { - settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled - } - if req.Settings.GroupsPropagationEnabled != nil { - settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled - } - if req.Settings.JwtGroupsClaimName != nil { - settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName - } - if req.Settings.JwtAllowGroups != nil { - settings.JWTAllowGroups = *req.Settings.JwtAllowGroups - } - if req.Settings.RoutingPeerDnsResolutionEnabled != nil { - settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled - } - if req.Settings.DnsDomain != nil { - settings.DNSDomain = *req.Settings.DnsDomain - } - if req.Settings.LazyConnectionEnabled != nil { - settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled + settings, err := h.updateAccountRequestSettings(req) + if err != nil { + util.WriteError(r.Context(), err, w) + return } if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" { prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange) @@ -313,6 +338,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, + AutoUpdateVersion: &settings.AutoUpdateVersion, } if settings.NetworkRange.IsValid() { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index c5c48ef32..2e48ac83e 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -121,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: true, expectedID: accountID, @@ -143,6 +144,30 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), + }, + expectedArray: false, + expectedID: accountID, + }, + { + name: "PutAccount OK with autoUpdateVersion", + expectedBody: true, + requestType: http.MethodPut, + requestPath: "/api/accounts/" + accountID, + requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, + RoutingPeerDnsResolutionEnabled: br(false), + LazyConnectionEnabled: br(false), + DnsDomain: sr(""), + AutoUpdateVersion: sr("latest"), }, expectedArray: false, expectedID: accountID, @@ -165,6 +190,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: false, expectedID: accountID, @@ -187,6 +213,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: false, expectedID: accountID, @@ -209,6 +236,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/types/settings.go b/management/server/types/settings.go index b4afb2f5e..867e12bef 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -52,6 +52,9 @@ type Settings struct { // LazyConnectionEnabled indicates if the experimental feature is enabled or disabled LazyConnectionEnabled bool `gorm:"default:false"` + + // AutoUpdateVersion client auto-update version + AutoUpdateVersion string `gorm:"default:'disabled'"` } // Copy copies the Settings struct @@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings { LazyConnectionEnabled: s.LazyConnectionEnabled, DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, + AutoUpdateVersion: s.AutoUpdateVersion, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 2d063a7b5..42f2fd89e 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -145,6 +145,10 @@ components: description: Enables or disables experimental lazy connection type: boolean example: true + auto_update_version: + description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") + type: string + example: "0.51.2" required: - peer_login_expiration_enabled - peer_login_expiration diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index d3e425548..74d9ca92b 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -291,6 +291,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { + // AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") + AutoUpdateVersion *string `json:"auto_update_version,omitempty"` + // DnsDomain Allows to define a custom dns domain for the account DnsDomain *string `json:"dns_domain,omitempty"` Extra *AccountExtraSettings `json:"extra,omitempty"` diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 2e4cf2644..217b33db9 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.32.1 +// protoc v3.21.12 // source: management.proto package proto @@ -267,7 +267,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24, 0} + return file_management_proto_rawDescGZIP(), []int{25, 0} } type EncryptedMessage struct { @@ -1762,6 +1762,8 @@ type PeerConfig struct { RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,6,opt,name=LazyConnectionEnabled,proto3" json:"LazyConnectionEnabled,omitempty"` Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"` + // Auto-update config + AutoUpdate *AutoUpdateSettings `protobuf:"bytes,8,opt,name=autoUpdate,proto3" json:"autoUpdate,omitempty"` } func (x *PeerConfig) Reset() { @@ -1845,6 +1847,70 @@ func (x *PeerConfig) GetMtu() int32 { return 0 } +func (x *PeerConfig) GetAutoUpdate() *AutoUpdateSettings { + if x != nil { + return x.AutoUpdate + } + return nil +} + +type AutoUpdateSettings struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` + // alwaysUpdate = true → Updates happen automatically in the background + // alwaysUpdate = false → Updates only happen when triggered by a peer connection + AlwaysUpdate bool `protobuf:"varint,2,opt,name=alwaysUpdate,proto3" json:"alwaysUpdate,omitempty"` +} + +func (x *AutoUpdateSettings) Reset() { + *x = AutoUpdateSettings{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AutoUpdateSettings) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AutoUpdateSettings) ProtoMessage() {} + +func (x *AutoUpdateSettings) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[20] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AutoUpdateSettings.ProtoReflect.Descriptor instead. +func (*AutoUpdateSettings) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{20} +} + +func (x *AutoUpdateSettings) GetVersion() string { + if x != nil { + return x.Version + } + return "" +} + +func (x *AutoUpdateSettings) GetAlwaysUpdate() bool { + if x != nil { + return x.AlwaysUpdate + } + return false +} + // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -1881,7 +1947,7 @@ type NetworkMap struct { func (x *NetworkMap) Reset() { *x = NetworkMap{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1894,7 +1960,7 @@ func (x *NetworkMap) String() string { func (*NetworkMap) ProtoMessage() {} func (x *NetworkMap) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1907,7 +1973,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead. func (*NetworkMap) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{20} + return file_management_proto_rawDescGZIP(), []int{21} } func (x *NetworkMap) GetSerial() uint64 { @@ -2015,7 +2081,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2028,7 +2094,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2041,7 +2107,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{21} + return file_management_proto_rawDescGZIP(), []int{22} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -2096,7 +2162,7 @@ type SSHConfig struct { func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2109,7 +2175,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[23] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2122,7 +2188,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{22} + return file_management_proto_rawDescGZIP(), []int{23} } func (x *SSHConfig) GetSshEnabled() bool { @@ -2156,7 +2222,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2169,7 +2235,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2182,7 +2248,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{24} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -2201,7 +2267,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2214,7 +2280,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2227,7 +2293,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{25} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -2254,7 +2320,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2267,7 +2333,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2280,7 +2346,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{26} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -2297,7 +2363,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2310,7 +2376,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2323,7 +2389,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -2369,7 +2435,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2382,7 +2448,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2395,7 +2461,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{28} } func (x *ProviderConfig) GetClientID() string { @@ -2503,7 +2569,7 @@ type Route struct { func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2516,7 +2582,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2529,7 +2595,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *Route) GetID() string { @@ -2618,7 +2684,7 @@ type DNSConfig struct { func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2631,7 +2697,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2644,7 +2710,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *DNSConfig) GetServiceEnable() bool { @@ -2691,7 +2757,7 @@ type CustomZone struct { func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2704,7 +2770,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2717,7 +2783,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *CustomZone) GetDomain() string { @@ -2764,7 +2830,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2777,7 +2843,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2790,7 +2856,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *SimpleRecord) GetName() string { @@ -2843,7 +2909,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2856,7 +2922,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2869,7 +2935,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{32} + return file_management_proto_rawDescGZIP(), []int{33} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -2914,7 +2980,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2927,7 +2993,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2940,7 +3006,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33} + return file_management_proto_rawDescGZIP(), []int{34} } func (x *NameServer) GetIP() string { @@ -2983,7 +3049,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2996,7 +3062,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3009,7 +3075,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{34} + return file_management_proto_rawDescGZIP(), []int{35} } func (x *FirewallRule) GetPeerIP() string { @@ -3073,7 +3139,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3086,7 +3152,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3099,7 +3165,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{35} + return file_management_proto_rawDescGZIP(), []int{36} } func (x *NetworkAddress) GetNetIP() string { @@ -3127,7 +3193,7 @@ type Checks struct { func (x *Checks) Reset() { *x = Checks{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3140,7 +3206,7 @@ func (x *Checks) String() string { func (*Checks) ProtoMessage() {} func (x *Checks) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3153,7 +3219,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message { // Deprecated: Use Checks.ProtoReflect.Descriptor instead. func (*Checks) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36} + return file_management_proto_rawDescGZIP(), []int{37} } func (x *Checks) GetFiles() []string { @@ -3178,7 +3244,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3191,7 +3257,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3204,7 +3270,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37} + return file_management_proto_rawDescGZIP(), []int{38} } func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -3275,7 +3341,7 @@ type RouteFirewallRule struct { func (x *RouteFirewallRule) Reset() { *x = RouteFirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3288,7 +3354,7 @@ func (x *RouteFirewallRule) String() string { func (*RouteFirewallRule) ProtoMessage() {} func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3301,7 +3367,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. func (*RouteFirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38} + return file_management_proto_rawDescGZIP(), []int{39} } func (x *RouteFirewallRule) GetSourceRanges() []string { @@ -3392,7 +3458,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3405,7 +3471,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3418,7 +3484,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{39} + return file_management_proto_rawDescGZIP(), []int{40} } func (x *ForwardingRule) GetProtocol() RuleProtocol { @@ -3461,7 +3527,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3474,7 +3540,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[41] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3487,7 +3553,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37, 0} + return file_management_proto_rawDescGZIP(), []int{38, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -3747,7 +3813,7 @@ var file_management_proto_rawDesc = []byte{ 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, - 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x22, 0xd3, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, @@ -3764,308 +3830,317 @@ var file_management_proto_rawDesc = []byte{ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, - 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, - 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, - 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, - 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, - 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, - 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, - 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, - 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, - 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, - 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, - 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, - 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, - 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, - 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, - 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, - 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, - 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, - 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, - 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, - 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, - 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, - 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, - 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, - 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, - 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, - 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, - 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, - 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, - 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, - 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, - 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, - 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, - 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, - 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, - 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, - 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, - 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, - 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, - 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, - 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, - 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, - 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, - 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, - 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, - 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, - 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, - 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, - 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, - 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, - 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, - 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, - 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, - 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, - 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, - 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, - 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, - 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, - 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, - 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, - 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, - 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, - 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, - 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, - 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, - 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, - 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, - 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, - 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, - 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, - 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, - 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, - 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, - 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, - 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, - 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, - 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, - 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, - 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, - 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, - 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, - 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, - 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, - 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, - 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, - 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, - 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, - 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, - 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, - 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, - 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, - 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, - 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, - 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, - 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, - 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, + 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, + 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, + 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, + 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, + 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, + 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, + 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, + 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, + 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, + 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, + 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, + 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, + 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, + 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, + 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, + 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, + 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, + 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, + 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, + 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, + 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, + 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, + 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, + 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, + 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, + 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, + 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, + 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, + 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, + 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, + 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, + 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, + 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, + 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, + 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, + 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, + 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, + 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, + 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, + 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, + 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, + 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, + 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x22, 0x74, + 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, + 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, + 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, + 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, + 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, + 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, + 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, + 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, + 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, + 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, + 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, + 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, + 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, + 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, + 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, + 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, + 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, + 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, + 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, + 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, + 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, + 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, + 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, + 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, + 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, + 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, + 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, + 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, + 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, + 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, - 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, - 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, - 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, + 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, - 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, + 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, + 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -4081,7 +4156,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 41) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 42) var file_management_proto_goTypes = []interface{}{ (RuleProtocol)(0), // 0: management.RuleProtocol (RuleDirection)(0), // 1: management.RuleDirection @@ -4108,106 +4183,108 @@ var file_management_proto_goTypes = []interface{}{ (*JWTConfig)(nil), // 22: management.JWTConfig (*ProtectedHostConfig)(nil), // 23: management.ProtectedHostConfig (*PeerConfig)(nil), // 24: management.PeerConfig - (*NetworkMap)(nil), // 25: management.NetworkMap - (*RemotePeerConfig)(nil), // 26: management.RemotePeerConfig - (*SSHConfig)(nil), // 27: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 28: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 29: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 30: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 31: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 32: management.ProviderConfig - (*Route)(nil), // 33: management.Route - (*DNSConfig)(nil), // 34: management.DNSConfig - (*CustomZone)(nil), // 35: management.CustomZone - (*SimpleRecord)(nil), // 36: management.SimpleRecord - (*NameServerGroup)(nil), // 37: management.NameServerGroup - (*NameServer)(nil), // 38: management.NameServer - (*FirewallRule)(nil), // 39: management.FirewallRule - (*NetworkAddress)(nil), // 40: management.NetworkAddress - (*Checks)(nil), // 41: management.Checks - (*PortInfo)(nil), // 42: management.PortInfo - (*RouteFirewallRule)(nil), // 43: management.RouteFirewallRule - (*ForwardingRule)(nil), // 44: management.ForwardingRule - (*PortInfo_Range)(nil), // 45: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 46: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 47: google.protobuf.Duration + (*AutoUpdateSettings)(nil), // 25: management.AutoUpdateSettings + (*NetworkMap)(nil), // 26: management.NetworkMap + (*RemotePeerConfig)(nil), // 27: management.RemotePeerConfig + (*SSHConfig)(nil), // 28: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 29: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 30: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 31: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 32: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 33: management.ProviderConfig + (*Route)(nil), // 34: management.Route + (*DNSConfig)(nil), // 35: management.DNSConfig + (*CustomZone)(nil), // 36: management.CustomZone + (*SimpleRecord)(nil), // 37: management.SimpleRecord + (*NameServerGroup)(nil), // 38: management.NameServerGroup + (*NameServer)(nil), // 39: management.NameServer + (*FirewallRule)(nil), // 40: management.FirewallRule + (*NetworkAddress)(nil), // 41: management.NetworkAddress + (*Checks)(nil), // 42: management.Checks + (*PortInfo)(nil), // 43: management.PortInfo + (*RouteFirewallRule)(nil), // 44: management.RouteFirewallRule + (*ForwardingRule)(nil), // 45: management.ForwardingRule + (*PortInfo_Range)(nil), // 46: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 47: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 48: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ 14, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta 18, // 1: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig 24, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 26, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 25, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 41, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 27, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 26, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 42, // 5: management.SyncResponse.Checks:type_name -> management.Checks 14, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 14, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 40, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 41, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment 12, // 11: management.PeerSystemMeta.files:type_name -> management.File 13, // 12: management.PeerSystemMeta.flags:type_name -> management.Flags 18, // 13: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig 24, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 41, // 15: management.LoginResponse.Checks:type_name -> management.Checks - 46, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 42, // 15: management.LoginResponse.Checks:type_name -> management.Checks + 47, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 19, // 17: management.NetbirdConfig.stuns:type_name -> management.HostConfig 23, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 47, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 48, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 27, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 24, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 26, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 33, // 28: management.NetworkMap.Routes:type_name -> management.Route - 34, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 26, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 39, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 43, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 44, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 27, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 22, // 35: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 4, // 36: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 32, // 37: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 32, // 38: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 37, // 39: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 35, // 40: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 36, // 41: management.CustomZone.Records:type_name -> management.SimpleRecord - 38, // 42: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 43: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 44: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 45: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 42, // 46: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 45, // 47: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 48: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 49: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 42, // 50: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 51: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 42, // 52: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 42, // 53: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 54: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 55: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 56: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 57: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 58: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 59: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 63: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 64: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 65: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 66: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 67: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 68: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 69: management.ManagementService.Logout:output_type -> management.Empty - 62, // [62:70] is the sub-list for method output_type - 54, // [54:62] is the sub-list for method input_type - 54, // [54:54] is the sub-list for extension type_name - 54, // [54:54] is the sub-list for extension extendee - 0, // [0:54] is the sub-list for field type_name + 28, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 25, // 26: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings + 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 27, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 34, // 29: management.NetworkMap.Routes:type_name -> management.Route + 35, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 27, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 40, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 44, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 45, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 28, // 35: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 36: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 37: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 33, // 38: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 33, // 39: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 38, // 40: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 36, // 41: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 37, // 42: management.CustomZone.Records:type_name -> management.SimpleRecord + 39, // 43: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 44: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 45: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 46: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 43, // 47: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 46, // 48: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 49: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 50: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 43, // 51: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 52: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 43, // 53: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 43, // 54: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 55: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 56: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 57: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 58: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 59: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 60: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 61: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 62: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 64: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 65: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 66: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 67: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 68: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 69: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 70: management.ManagementService.Logout:output_type -> management.Empty + 63, // [63:71] is the sub-list for method output_type + 55, // [55:63] is the sub-list for method input_type + 55, // [55:55] is the sub-list for extension type_name + 55, // [55:55] is the sub-list for extension extendee + 0, // [0:55] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -4457,7 +4534,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*AutoUpdateSettings); i { case 0: return &v.state case 1: @@ -4469,7 +4546,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -4481,7 +4558,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -4493,7 +4570,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -4505,7 +4582,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4517,7 +4594,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4529,7 +4606,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4541,7 +4618,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4553,7 +4630,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -4565,7 +4642,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -4577,7 +4654,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -4589,7 +4666,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -4601,7 +4678,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -4613,7 +4690,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -4625,7 +4702,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -4637,7 +4714,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -4649,7 +4726,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Checks); i { + switch v := v.(*NetworkAddress); i { case 0: return &v.state case 1: @@ -4661,7 +4738,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PortInfo); i { + switch v := v.(*Checks); i { case 0: return &v.state case 1: @@ -4673,7 +4750,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouteFirewallRule); i { + switch v := v.(*PortInfo); i { case 0: return &v.state case 1: @@ -4685,7 +4762,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ForwardingRule); i { + switch v := v.(*RouteFirewallRule); i { case 0: return &v.state case 1: @@ -4697,6 +4774,18 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForwardingRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -4709,7 +4798,7 @@ func file_management_proto_init() { } } } - file_management_proto_msgTypes[37].OneofWrappers = []interface{}{ + file_management_proto_msgTypes[38].OneofWrappers = []interface{}{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } @@ -4719,7 +4808,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 41, + NumMessages: 42, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index dc60b026d..fec51ca91 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -280,6 +280,18 @@ message PeerConfig { bool LazyConnectionEnabled = 6; int32 mtu = 7; + + // Auto-update config + AutoUpdateSettings autoUpdate = 8; +} + +message AutoUpdateSettings { + string version = 1; + /* + alwaysUpdate = true → Updates happen automatically in the background + alwaysUpdate = false → Updates only happen when triggered by a peer connection + */ + bool alwaysUpdate = 2; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections diff --git a/version/update.go b/version/update.go index 272eef4c6..a324d97fe 100644 --- a/version/update.go +++ b/version/update.go @@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update { currentVersion, _ = goversion.NewVersion("0.0.0") } - latestAvailable, _ := goversion.NewVersion("0.0.0") - u := &Update{ - httpAgent: httpAgent, - latestAvailable: latestAvailable, - uiVersion: currentVersion, - fetchTicker: time.NewTicker(fetchPeriod), - fetchDone: make(chan struct{}), + httpAgent: httpAgent, + uiVersion: currentVersion, + fetchDone: make(chan struct{}), } - go u.startFetcher() + + return u +} + +func NewUpdateAndStart(httpAgent string) *Update { + u := NewUpdate(httpAgent) + go u.StartFetcher() + return u } // StopWatch stop the version info fetch loop func (u *Update) StopWatch() { + if u.fetchTicker == nil { + return + } + u.fetchTicker.Stop() select { @@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) { } } -func (u *Update) startFetcher() { +func (u *Update) LatestVersion() *goversion.Version { + u.versionsLock.Lock() + defer u.versionsLock.Unlock() + return u.latestAvailable +} + +func (u *Update) StartFetcher() { + if u.fetchTicker != nil { + return + } + u.fetchTicker = time.NewTicker(fetchPeriod) + if changed := u.fetchVersion(); changed { u.checkUpdate() } @@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool { u.versionsLock.Lock() defer u.versionsLock.Unlock() + if u.latestAvailable == nil { + return false + } + if u.latestAvailable.GreaterThan(u.uiVersion) { return true } diff --git a/version/update_test.go b/version/update_test.go index a733714cf..d5d60800e 100644 --- a/version/update_test.go +++ b/version/update_test.go @@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true From 433bc4ead9d9f798f5d6a2e3962be95b67261e55 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 22 Dec 2025 20:04:52 +0100 Subject: [PATCH 107/118] [client] lookup for management domains using an additional timeout (#4983) in some cases iOS and macOS may be locked when looking for management domains during network changes This change introduce an additional timeout on top of the context call --- client/internal/dns/mgmt/mgmt.go | 40 ++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 290395473..d01be0c2c 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "net/url" "strings" "sync" @@ -26,6 +27,11 @@ type Resolver struct { mutex sync.RWMutex } +type ipsResponse struct { + ips []netip.Addr + err error +} + // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ @@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + ips, err := lookupIPWithExtraTimeout(ctx, d) if err != nil { - return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) + return err } var aRecords, aaaaRecords []dns.RR @@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { return nil } +func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { + log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) + defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) + resultChan := make(chan *ipsResponse, 1) + + go func() { + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + resultChan <- &ipsResponse{ + err: err, + ips: ips, + } + }() + + var resp *ipsResponse + + select { + case <-time.After(dnsTimeout + time.Millisecond*500): + log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) + return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) + case <-ctx.Done(): + return nil, ctx.Err() + case resp = <-resultChan: + } + + if resp.err != nil { + return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + } + return resp.ips, nil +} + // PopulateFromConfig extracts and caches domains from the client configuration. func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { if mgmtURL == nil { From b7e98acd1f5f329366864d4add4c8350750c4d96 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 22 Dec 2025 22:09:05 +0100 Subject: [PATCH 108/118] [client] Android profile switch (#4884) Expose the profile-manager service for Android. Logout was not part of the manager service implementation. In the future, I recommend moving this logic there. --- client/android/client.go | 30 ++- client/android/profile_manager.go | 257 ++++++++++++++++++++++ client/internal/connect.go | 12 +- client/internal/profilemanager/service.go | 29 ++- 4 files changed, 307 insertions(+), 21 deletions(-) create mode 100644 client/android/profile_manager.go diff --git a/client/android/client.go b/client/android/client.go index b16a23d64..ccf32a90c 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -59,7 +59,6 @@ func init() { // Client struct manage the life circle of background service type Client struct { - cfgFile string tunAdapter device.TunAdapter iFaceDiscover IFaceDiscover recorder *peer.Status @@ -68,18 +67,16 @@ type Client struct { deviceName string uiVersion string networkChangeListener listener.NetworkChangeListener - stateFile string connectClient *internal.ConnectClient } // NewClient instantiate a new Client -func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { execWorkaround(androidSDKVersion) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ - cfgFile: platformFiles.ConfigurationFilePath(), deviceName: deviceName, uiVersion: uiVersion, tunAdapter: tunAdapter, @@ -87,15 +84,20 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st recorder: peer.NewRecorder(""), ctxCancelLock: &sync.Mutex{}, networkChangeListener: networkChangeListener, - stateFile: platformFiles.StateFilePath(), } } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) + + cfgFile := platformFiles.ConfigurationFilePath() + stateFile := platformFiles.StateFilePath() + + log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: cfgFile, }) if err != nil { return err @@ -123,15 +125,21 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) + + cfgFile := platformFiles.ConfigurationFilePath() + stateFile := platformFiles.StateFilePath() + + log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: cfgFile, }) if err != nil { return err @@ -150,7 +158,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } // Stop the internal client and free the resources diff --git a/client/android/profile_manager.go b/client/android/profile_manager.go new file mode 100644 index 000000000..60e4d5c32 --- /dev/null +++ b/client/android/profile_manager.go @@ -0,0 +1,257 @@ +//go:build android + +package android + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/profilemanager" +) + +const ( + // Android-specific config filename (different from desktop default.json) + defaultConfigFilename = "netbird.cfg" + // Subdirectory for non-default profiles (must match Java Preferences.java) + profilesSubdir = "profiles" + // Android uses a single user context per app (non-empty username required by ServiceManager) + androidUsername = "android" +) + +// Profile represents a profile for gomobile +type Profile struct { + Name string + IsActive bool +} + +// ProfileArray wraps profiles for gomobile compatibility +type ProfileArray struct { + items []*Profile +} + +// Length returns the number of profiles +func (p *ProfileArray) Length() int { + return len(p.items) +} + +// Get returns the profile at index i +func (p *ProfileArray) Get(i int) *Profile { + if i < 0 || i >= len(p.items) { + return nil + } + return p.items[i] +} + +/* + +/data/data/io.netbird.client/files/ ← configDir parameter +├── netbird.cfg ← Default profile config +├── state.json ← Default profile state +├── active_profile.json ← Active profile tracker (JSON with Name + Username) +└── profiles/ ← Subdirectory for non-default profiles + ├── work.json ← Work profile config + ├── work.state.json ← Work profile state + ├── personal.json ← Personal profile config + └── personal.state.json ← Personal profile state +*/ + +// ProfileManager manages profiles for Android +// It wraps the internal profilemanager to provide Android-specific behavior +type ProfileManager struct { + configDir string + serviceMgr *profilemanager.ServiceManager +} + +// NewProfileManager creates a new profile manager for Android +func NewProfileManager(configDir string) *ProfileManager { + // Set the default config path for Android (stored in root configDir, not profiles/) + defaultConfigPath := filepath.Join(configDir, defaultConfigFilename) + + // Set global paths for Android + profilemanager.DefaultConfigPathDir = configDir + profilemanager.DefaultConfigPath = defaultConfigPath + profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json") + + // Create ServiceManager with profiles/ subdirectory + // This avoids modifying the global ConfigDirOverride for profile listing + profilesDir := filepath.Join(configDir, profilesSubdir) + serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir) + + return &ProfileManager{ + configDir: configDir, + serviceMgr: serviceMgr, + } +} + +// ListProfiles returns all available profiles +func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) { + // Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive) + internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername) + if err != nil { + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + // Convert internal profiles to Android Profile type + var profiles []*Profile + for _, p := range internalProfiles { + profiles = append(profiles, &Profile{ + Name: p.Name, + IsActive: p.IsActive, + }) + } + + return &ProfileArray{items: profiles}, nil +} + +// GetActiveProfile returns the currently active profile name +func (pm *ProfileManager) GetActiveProfile() (string, error) { + // Use ServiceManager to stay consistent with ListProfiles + // ServiceManager uses active_profile.json + activeState, err := pm.serviceMgr.GetActiveProfileState() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return activeState.Name, nil +} + +// SwitchProfile switches to a different profile +func (pm *ProfileManager) SwitchProfile(profileName string) error { + // Use ServiceManager to stay consistent with ListProfiles + // ServiceManager uses active_profile.json + err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profileName, + Username: androidUsername, + }) + if err != nil { + return fmt.Errorf("failed to switch profile: %w", err) + } + + log.Infof("switched to profile: %s", profileName) + return nil +} + +// AddProfile creates a new profile +func (pm *ProfileManager) AddProfile(profileName string) error { + // Use ServiceManager (creates profile in profiles/ directory) + if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil { + return fmt.Errorf("failed to add profile: %w", err) + } + + log.Infof("created new profile: %s", profileName) + return nil +} + +// LogoutProfile logs out from a profile (clears authentication) +func (pm *ProfileManager) LogoutProfile(profileName string) error { + profileName = sanitizeProfileName(profileName) + + configPath, err := pm.getProfileConfigPath(profileName) + if err != nil { + return err + } + + // Check if profile exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return fmt.Errorf("profile '%s' does not exist", profileName) + } + + // Read current config using internal profilemanager + config, err := profilemanager.ReadConfig(configPath) + if err != nil { + return fmt.Errorf("failed to read profile config: %w", err) + } + + // Clear authentication by removing private key and SSH key + config.PrivateKey = "" + config.SSHKey = "" + + // Save config using internal profilemanager + if err := profilemanager.WriteOutConfig(configPath, config); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + + log.Infof("logged out from profile: %s", profileName) + return nil +} + +// RemoveProfile deletes a profile +func (pm *ProfileManager) RemoveProfile(profileName string) error { + // Use ServiceManager (removes profile from profiles/ directory) + if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil { + return fmt.Errorf("failed to remove profile: %w", err) + } + + log.Infof("removed profile: %s", profileName) + return nil +} + +// getProfileConfigPath returns the config file path for a profile +// This is needed for Android-specific path handling (netbird.cfg for default profile) +func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) { + if profileName == "" || profileName == profilemanager.DefaultProfileName { + // Android uses netbird.cfg for default profile instead of default.json + // Default profile is stored in root configDir, not in profiles/ + return filepath.Join(pm.configDir, defaultConfigFilename), nil + } + + // Non-default profiles are stored in profiles subdirectory + // This matches the Java Preferences.java expectation + profileName = sanitizeProfileName(profileName) + profilesDir := filepath.Join(pm.configDir, profilesSubdir) + return filepath.Join(profilesDir, profileName+".json"), nil +} + +// GetConfigPath returns the config file path for a given profile +// Java should call this instead of constructing paths with Preferences.configFile() +func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) { + return pm.getProfileConfigPath(profileName) +} + +// GetStateFilePath returns the state file path for a given profile +// Java should call this instead of constructing paths with Preferences.stateFile() +func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) { + if profileName == "" || profileName == profilemanager.DefaultProfileName { + return filepath.Join(pm.configDir, "state.json"), nil + } + + profileName = sanitizeProfileName(profileName) + profilesDir := filepath.Join(pm.configDir, profilesSubdir) + return filepath.Join(profilesDir, profileName+".state.json"), nil +} + +// GetActiveConfigPath returns the config file path for the currently active profile +// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile() +func (pm *ProfileManager) GetActiveConfigPath() (string, error) { + activeProfile, err := pm.GetActiveProfile() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return pm.GetConfigPath(activeProfile) +} + +// GetActiveStateFilePath returns the state file path for the currently active profile +// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile() +func (pm *ProfileManager) GetActiveStateFilePath() (string, error) { + activeProfile, err := pm.GetActiveProfile() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return pm.GetStateFilePath(activeProfile) +} + +// sanitizeProfileName removes invalid characters from profile name +func sanitizeProfileName(name string) string { + // Keep only alphanumeric, underscore, and hyphen + var result strings.Builder + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || r == '_' || r == '-' { + result.WriteRune(r) + } + } + return result.String() +} diff --git a/client/internal/connect.go b/client/internal/connect.go index 79ec2f907..017c8bf10 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -170,19 +170,19 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return err } - sm := profilemanager.NewServiceManager("") - - path := sm.GetStatePath() + var path string if runtime.GOOS == "ios" || runtime.GOOS == "android" { + // On mobile, use the provided state file path directly if !fileExists(mobileDependency.StateFilePath) { - err := createFile(mobileDependency.StateFilePath) - if err != nil { + if err := createFile(mobileDependency.StateFilePath); err != nil { log.Errorf("failed to create state file: %v", err) // we are not exiting as we can run without the state manager } } - path = mobileDependency.StateFilePath + } else { + sm := profilemanager.NewServiceManager("") + path = sm.GetStatePath() } stateManager := statemanager.New(path) stateManager.RegisterState(&sshconfig.ShutdownState{}) diff --git a/client/internal/profilemanager/service.go b/client/internal/profilemanager/service.go index faccf5f68..5a0c14000 100644 --- a/client/internal/profilemanager/service.go +++ b/client/internal/profilemanager/service.go @@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) { } type ServiceManager struct { + profilesDir string // If set, overrides ConfigDirOverride for profile operations } func NewServiceManager(defaultConfigPath string) *ServiceManager { @@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager { return &ServiceManager{} } +// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory +// This allows setting the profiles directory without modifying the global ConfigDirOverride +func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager { + if defaultConfigPath != "" { + DefaultConfigPath = defaultConfigPath + } + return &ServiceManager{ + profilesDir: profilesDir, + } +} + func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) { if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil { @@ -240,7 +252,7 @@ func (s *ServiceManager) DefaultProfilePath() string { } func (s *ServiceManager) AddProfile(profileName, username string) error { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return fmt.Errorf("failed to get config directory: %w", err) } @@ -270,7 +282,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error { } func (s *ServiceManager) RemoveProfile(profileName, username string) error { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return fmt.Errorf("failed to get config directory: %w", err) } @@ -302,7 +314,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error { } func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return nil, fmt.Errorf("failed to get config directory: %w", err) } @@ -361,7 +373,7 @@ func (s *ServiceManager) GetStatePath() string { return defaultStatePath } - configDir, err := getConfigDirForUser(activeProf.Username) + configDir, err := s.getConfigDir(activeProf.Username) if err != nil { log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err) return defaultStatePath @@ -369,3 +381,12 @@ func (s *ServiceManager) GetStatePath() string { return filepath.Join(configDir, activeProf.Name+".state.json") } + +// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser +func (s *ServiceManager) getConfigDir(username string) (string, error) { + if s.profilesDir != "" { + return s.profilesDir, nil + } + + return getConfigDirForUser(username) +} From fc4932a23fc72046eb44d5b136e0a963ca262019 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 24 Dec 2025 11:06:13 +0100 Subject: [PATCH 109/118] [client] Fix Linux UI flickering on state updates (#4886) --- client/ui/client_ui.go | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index c786d3a61..87bac8c31 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -312,6 +312,8 @@ type serviceClient struct { daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool + settingsEnabled bool + profilesEnabled bool showNetworks bool wNetworks fyne.Window wProfiles fyne.Window @@ -907,7 +909,7 @@ func (s *serviceClient) updateStatus() error { var systrayIconState bool switch { - case status.Status == string(internal.StatusConnected): + case status.Status == string(internal.StatusConnected) && !s.mUp.Disabled(): s.connected = true s.sendNotification = true if s.isUpdateIconActive { @@ -921,6 +923,7 @@ func (s *serviceClient) updateStatus() error { s.mUp.Disable() s.mDown.Enable() s.mNetworks.Enable() + s.mExitNode.Enable() go s.updateExitNodes() systrayIconState = true case status.Status == string(internal.StatusConnecting): @@ -1274,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() { return } + s.updateIndicationLock.Lock() + defer s.updateIndicationLock.Unlock() + // Update settings menu based on current features - if features != nil && features.DisableUpdateSettings { - s.setSettingsEnabled(false) - } else { - s.setSettingsEnabled(true) + settingsEnabled := features == nil || !features.DisableUpdateSettings + if s.settingsEnabled != settingsEnabled { + s.settingsEnabled = settingsEnabled + s.setSettingsEnabled(settingsEnabled) } // Update profile menu based on current features if s.mProfile != nil { - if features != nil && features.DisableProfiles { - s.mProfile.setEnabled(false) - } else { - s.mProfile.setEnabled(true) + profilesEnabled := features == nil || !features.DisableProfiles + if s.profilesEnabled != profilesEnabled { + s.profilesEnabled = profilesEnabled + s.mProfile.setEnabled(profilesEnabled) } } } From d3b123c76dfab8daac58a1bcd922f4c4eb4c5e10 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 24 Dec 2025 11:22:33 +0100 Subject: [PATCH 110/118] [ci] Add FreeBSD port release job to GitHub Actions (#4916) adds a job that produces new freebsd release files --- .github/workflows/release.yml | 81 +++++++++ release_files/freebsd-port-diff.sh | 216 +++++++++++++++++++++++ release_files/freebsd-port-issue-body.sh | 159 +++++++++++++++++ 3 files changed, 456 insertions(+) create mode 100755 release_files/freebsd-port-diff.sh create mode 100755 release_files/freebsd-port-issue-body.sh diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a9bc1b979..15a200987 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,6 +19,87 @@ concurrency: cancel-in-progress: true jobs: + release_freebsd_port: + name: "FreeBSD Port / Build & Test" + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Generate FreeBSD port diff + run: bash release_files/freebsd-port-diff.sh + + - name: Generate FreeBSD port issue body + run: bash release_files/freebsd-port-issue-body.sh + + - name: Extract version + id: version + run: | + VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/') + echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "Generated files for version: $VERSION" + cat netbird-*.diff + + - name: Test FreeBSD port + uses: vmactions/freebsd-vm@v1 + with: + usesh: true + copyback: false + release: "15.0" + prepare: | + # Install required packages + pkg install -y git curl portlint go + + # Install Go for building + GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" + GO_URL="https://go.dev/dl/$GO_TARBALL" + curl -LO "$GO_URL" + tar -C /usr/local -xzf "$GO_TARBALL" + + # Clone ports tree (shallow, only what we need) + git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports + cd /usr/ports + + run: | + set -e -x + export PATH=$PATH:/usr/local/go/bin + + # Find the diff file + echo "Finding diff file..." + DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1) + echo "Found: $DIFF_FILE" + + if [[ -z "$DIFF_FILE" ]]; then + echo "ERROR: Could not find diff file" + find ~ -name "*.diff" -type f 2>/dev/null || true + exit 1 + fi + + # Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths) + cd /usr/ports + patch -p1 -V none < "$DIFF_FILE" + + # Show patched Makefile + version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}') + + cd /usr/ports/security/netbird + export BATCH=yes + make package + pkg add ./work/pkg/netbird-*.pkg + + netbird version | grep "$version" + + echo "FreeBSD port test completed successfully!" + + - name: Upload FreeBSD port files + uses: actions/upload-artifact@v4 + with: + name: freebsd-port-files + path: | + ./netbird-*-issue.txt + ./netbird-*.diff + retention-days: 30 + release: runs-on: ubuntu-latest-m env: diff --git a/release_files/freebsd-port-diff.sh b/release_files/freebsd-port-diff.sh new file mode 100755 index 000000000..b030b9164 --- /dev/null +++ b/release_files/freebsd-port-diff.sh @@ -0,0 +1,216 @@ +#!/bin/bash +# +# FreeBSD Port Diff Generator for NetBird +# +# This script generates the diff file required for submitting a FreeBSD port update. +# It works on macOS, Linux, and FreeBSD by fetching files from FreeBSD cgit and +# computing checksums from the Go module proxy. +# +# Usage: ./freebsd-port-diff.sh [new_version] +# Example: ./freebsd-port-diff.sh 0.60.7 +# +# If no version is provided, it fetches the latest from GitHub. + +set -e + +GITHUB_REPO="netbirdio/netbird" +PORTS_CGIT_BASE="https://cgit.freebsd.org/ports/plain/security/netbird" +GO_PROXY="https://proxy.golang.org/github.com/netbirdio/netbird/@v" +OUTPUT_DIR="${OUTPUT_DIR:-.}" +AWK_FIRST_FIELD='{print $1}' + +fetch_all_tags() { + curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ + grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ + sed 's/.*\/v//' | \ + sort -u -V + return 0 +} + +fetch_current_ports_version() { + echo "Fetching current version from FreeBSD ports..." >&2 + curl -sL "${PORTS_CGIT_BASE}/Makefile" 2>/dev/null | \ + grep -E "^DISTVERSION=" | \ + sed 's/DISTVERSION=[[:space:]]*//' | \ + tr -d '\t ' + return 0 +} + +fetch_latest_github_release() { + echo "Fetching latest release from GitHub..." >&2 + fetch_all_tags | tail -1 + return 0 +} + +fetch_ports_file() { + local filename="$1" + curl -sL "${PORTS_CGIT_BASE}/${filename}" 2>/dev/null + return 0 +} + +compute_checksums() { + local version="$1" + local tmpdir + tmpdir=$(mktemp -d) + # shellcheck disable=SC2064 + trap "rm -rf '$tmpdir'" EXIT + + echo "Downloading files from Go module proxy for v${version}..." >&2 + + local mod_file="${tmpdir}/v${version}.mod" + local zip_file="${tmpdir}/v${version}.zip" + + curl -sL "${GO_PROXY}/v${version}.mod" -o "$mod_file" 2>/dev/null + curl -sL "${GO_PROXY}/v${version}.zip" -o "$zip_file" 2>/dev/null + + if [[ ! -s "$mod_file" ]] || [[ ! -s "$zip_file" ]]; then + echo "Error: Could not download files from Go module proxy" >&2 + return 1 + fi + + local mod_sha256 mod_size zip_sha256 zip_size + + if command -v sha256sum &>/dev/null; then + mod_sha256=$(sha256sum "$mod_file" | awk "$AWK_FIRST_FIELD") + zip_sha256=$(sha256sum "$zip_file" | awk "$AWK_FIRST_FIELD") + elif command -v shasum &>/dev/null; then + mod_sha256=$(shasum -a 256 "$mod_file" | awk "$AWK_FIRST_FIELD") + zip_sha256=$(shasum -a 256 "$zip_file" | awk "$AWK_FIRST_FIELD") + else + echo "Error: No sha256 command found" >&2 + return 1 + fi + + if [[ "$OSTYPE" == "darwin"* ]]; then + mod_size=$(stat -f%z "$mod_file") + zip_size=$(stat -f%z "$zip_file") + else + mod_size=$(stat -c%s "$mod_file") + zip_size=$(stat -c%s "$zip_file") + fi + + echo "TIMESTAMP = $(date +%s)" + echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_sha256}" + echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_size}" + echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_sha256}" + echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_size}" + return 0 +} + +generate_new_makefile() { + local new_version="$1" + local old_makefile="$2" + + # Check if old version had PORTREVISION + if echo "$old_makefile" | grep -q "^PORTREVISION="; then + # Remove PORTREVISION line and update DISTVERSION + echo "$old_makefile" | \ + sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" | \ + grep -v "^PORTREVISION=" + else + # Just update DISTVERSION + echo "$old_makefile" | \ + sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" + fi + return 0 +} + +# Parse arguments +NEW_VERSION="${1:-}" + +# Auto-detect versions if not provided +OLD_VERSION=$(fetch_current_ports_version) +if [[ -z "$OLD_VERSION" ]]; then + echo "Error: Could not fetch current version from FreeBSD ports" >&2 + exit 1 +fi +echo "Current FreeBSD ports version: ${OLD_VERSION}" >&2 + +if [[ -z "$NEW_VERSION" ]]; then + NEW_VERSION=$(fetch_latest_github_release) + if [[ -z "$NEW_VERSION" ]]; then + echo "Error: Could not fetch latest release from GitHub" >&2 + exit 1 + fi +fi +echo "Target version: ${NEW_VERSION}" >&2 + +if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then + echo "Port is already at version ${NEW_VERSION}. Nothing to do." >&2 + exit 0 +fi + +echo "" >&2 + +# Fetch current files +echo "Fetching current Makefile from FreeBSD ports..." >&2 +OLD_MAKEFILE=$(fetch_ports_file "Makefile") +if [[ -z "$OLD_MAKEFILE" ]]; then + echo "Error: Could not fetch Makefile" >&2 + exit 1 +fi + +echo "Fetching current distinfo from FreeBSD ports..." >&2 +OLD_DISTINFO=$(fetch_ports_file "distinfo") +if [[ -z "$OLD_DISTINFO" ]]; then + echo "Error: Could not fetch distinfo" >&2 + exit 1 +fi + +# Generate new files +echo "Generating new Makefile..." >&2 +NEW_MAKEFILE=$(generate_new_makefile "$NEW_VERSION" "$OLD_MAKEFILE") + +echo "Computing checksums for new version..." >&2 +NEW_DISTINFO=$(compute_checksums "$NEW_VERSION") +if [[ -z "$NEW_DISTINFO" ]]; then + echo "Error: Could not compute checksums" >&2 + exit 1 +fi + +# Create temp files for diff +TMPDIR=$(mktemp -d) +# shellcheck disable=SC2064 +trap "rm -rf '$TMPDIR'" EXIT + +mkdir -p "${TMPDIR}/a/security/netbird" "${TMPDIR}/b/security/netbird" + +echo "$OLD_MAKEFILE" > "${TMPDIR}/a/security/netbird/Makefile" +echo "$OLD_DISTINFO" > "${TMPDIR}/a/security/netbird/distinfo" +echo "$NEW_MAKEFILE" > "${TMPDIR}/b/security/netbird/Makefile" +echo "$NEW_DISTINFO" > "${TMPDIR}/b/security/netbird/distinfo" + +# Generate diff +OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}.diff" + +echo "" >&2 +echo "Generating diff..." >&2 + +# Generate diff and clean up temp paths to show standard a/b paths +(cd "${TMPDIR}" && diff -ruN "a/security/netbird" "b/security/netbird") > "$OUTPUT_FILE" || true + +if [[ ! -s "$OUTPUT_FILE" ]]; then + echo "Error: Generated diff is empty" >&2 + exit 1 +fi + +echo "" >&2 +echo "=========================================" +echo "Diff saved to: ${OUTPUT_FILE}" +echo "=========================================" +echo "" +cat "$OUTPUT_FILE" +echo "" +echo "=========================================" +echo "" +echo "Next steps:" +echo "1. Review the diff above" +echo "2. Submit to https://bugs.freebsd.org/bugzilla/" +echo "3. Use ./freebsd-port-issue-body.sh to generate the issue content" +echo "" +echo "For FreeBSD testing (optional but recommended):" +echo " cd /usr/ports/security/netbird" +echo " patch < ${OUTPUT_FILE}" +echo " make stage && make stage-qa && make package && make install" +echo " netbird status" +echo " make deinstall" diff --git a/release_files/freebsd-port-issue-body.sh b/release_files/freebsd-port-issue-body.sh new file mode 100755 index 000000000..b7ad0f5b1 --- /dev/null +++ b/release_files/freebsd-port-issue-body.sh @@ -0,0 +1,159 @@ +#!/bin/bash +# +# FreeBSD Port Issue Body Generator for NetBird +# +# This script generates the issue body content for submitting a FreeBSD port update +# to the FreeBSD Bugzilla at https://bugs.freebsd.org/bugzilla/ +# +# Usage: ./freebsd-port-issue-body.sh [old_version] [new_version] +# Example: ./freebsd-port-issue-body.sh 0.56.0 0.59.1 +# +# If no versions are provided, the script will: +# - Fetch OLD version from FreeBSD ports cgit (current version in ports tree) +# - Fetch NEW version from latest NetBird GitHub release tag + +set -e + +GITHUB_REPO="netbirdio/netbird" +PORTS_CGIT_URL="https://cgit.freebsd.org/ports/plain/security/netbird/Makefile" + +fetch_current_ports_version() { + echo "Fetching current version from FreeBSD ports..." >&2 + local makefile_content + makefile_content=$(curl -sL "$PORTS_CGIT_URL" 2>/dev/null) + if [[ -z "$makefile_content" ]]; then + echo "Error: Could not fetch Makefile from FreeBSD ports" >&2 + return 1 + fi + echo "$makefile_content" | grep -E "^DISTVERSION=" | sed 's/DISTVERSION=[[:space:]]*//' | tr -d '\t ' + return 0 +} + +fetch_all_tags() { + # Fetch tags from GitHub tags page (no rate limiting, no auth needed) + curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ + grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ + sed 's/.*\/v//' | \ + sort -u -V + return 0 +} + +fetch_latest_github_release() { + echo "Fetching latest release from GitHub..." >&2 + local latest + + # Fetch from GitHub tags page + latest=$(fetch_all_tags | tail -1) + + if [[ -z "$latest" ]]; then + # Fallback to GitHub API + latest=$(curl -sL "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | \ + grep '"tag_name"' | sed 's/.*"tag_name": *"v\([^"]*\)".*/\1/') + fi + + if [[ -z "$latest" ]]; then + echo "Error: Could not fetch latest release from GitHub" >&2 + return 1 + fi + echo "$latest" + return 0 +} + +OLD_VERSION="${1:-}" +NEW_VERSION="${2:-}" + +if [[ -z "$OLD_VERSION" ]]; then + OLD_VERSION=$(fetch_current_ports_version) + if [[ -z "$OLD_VERSION" ]]; then + echo "Error: Could not determine old version. Please provide it manually." >&2 + echo "Usage: $0 " >&2 + exit 1 + fi + echo "Detected OLD version from FreeBSD ports: $OLD_VERSION" >&2 +fi + +if [[ -z "$NEW_VERSION" ]]; then + NEW_VERSION=$(fetch_latest_github_release) + if [[ -z "$NEW_VERSION" ]]; then + echo "Error: Could not determine new version. Please provide it manually." >&2 + echo "Usage: $0 " >&2 + exit 1 + fi + echo "Detected NEW version from GitHub: $NEW_VERSION" >&2 +fi + +if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then + echo "Warning: OLD and NEW versions are the same ($OLD_VERSION). Port may already be up to date." >&2 +fi + +echo "" >&2 + +OUTPUT_DIR="${OUTPUT_DIR:-.}" + +fetch_releases_between_versions() { + echo "Fetching release history from GitHub..." >&2 + + # Fetch all tags and filter to those between OLD and NEW versions + fetch_all_tags | \ + while read -r ver; do + if [[ "$(printf '%s\n' "$OLD_VERSION" "$ver" | sort -V | head -n1)" = "$OLD_VERSION" ]] && \ + [[ "$(printf '%s\n' "$ver" "$NEW_VERSION" | sort -V | head -n1)" = "$ver" ]] && \ + [[ "$ver" != "$OLD_VERSION" ]]; then + echo "$ver" + fi + done + return 0 +} + +generate_changelog_section() { + local releases + releases=$(fetch_releases_between_versions) + + echo "Changelogs:" + if [[ -n "$releases" ]]; then + echo "$releases" | while read -r ver; do + echo "https://github.com/${GITHUB_REPO}/releases/tag/v${ver}" + done + else + echo "https://github.com/${GITHUB_REPO}/releases/tag/v${NEW_VERSION}" + fi + return 0 +} + +OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}-issue.txt" + +cat << EOF > "$OUTPUT_FILE" +BUGZILLA ISSUE DETAILS +====================== + +Severity: Affects Some People + +Summary: security/netbird: Update to ${NEW_VERSION} + +Description: +------------ +security/netbird: Update ${OLD_VERSION} => ${NEW_VERSION} + +$(generate_changelog_section) + +Commit log: +https://github.com/${GITHUB_REPO}/compare/v${OLD_VERSION}...v${NEW_VERSION} +EOF + +echo "=========================================" +echo "Issue body saved to: ${OUTPUT_FILE}" +echo "=========================================" +echo "" +cat "$OUTPUT_FILE" +echo "" +echo "=========================================" +echo "" +echo "Next steps:" +echo "1. Go to https://bugs.freebsd.org/bugzilla/ and login" +echo "2. Click 'Report an update or defect to a port'" +echo "3. Fill in:" +echo " - Severity: Affects Some People" +echo " - Summary: security/netbird: Update to ${NEW_VERSION}" +echo " - Description: Copy content from ${OUTPUT_FILE}" +echo "4. Attach diff file: netbird-${NEW_VERSION}.diff" +echo "5. Submit the bug report" From ab6a9e85de3ea8ea669b27fcf3f38fcbc513799c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 24 Dec 2025 18:03:14 +0100 Subject: [PATCH 111/118] [misc] Use new sign pipelines 0.1.0 (#4993) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 15a200987..345728285 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.23" + SIGN_PIPE_VER: "v0.1.0" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" From aa914a0f268e675cea27820875ad26b671a82568 Mon Sep 17 00:00:00 2001 From: August <112662403+hey-august@users.noreply.github.com> Date: Wed, 24 Dec 2025 12:06:35 -0500 Subject: [PATCH 112/118] [docs] Fix broken image link (#4876) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2c5ee2ab6..ebf108cdb 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.

- +

See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details. From 33d1761fe89be0f8a8e6a8d9c6a8f83c634d8ea3 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Mon, 29 Dec 2025 04:43:57 -0700 Subject: [PATCH 113/118] Apply DNS host config on change only (#4695) Adds a per-instance uint64 hash to DefaultServer to detect identical merged host DNS configs (including extra domains). applyHostConfig computes and compares the hash, skips applying if unchanged, treats hash errors as a fail-safe (proceed to apply), and updates the stored hash only after successful hashing and apply. --- client/internal/dns/server.go | 23 +++++++++++++++++++++++ client/internal/dns/server_test.go | 10 ++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index afaf0579f..94945b55a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -80,6 +80,7 @@ type DefaultServer struct { updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + currentConfigHash uint64 handlerChain *HandlerChain extraDomains map[domain.Domain]int @@ -207,6 +208,7 @@ func newDefaultServer( hostsDNSHolder: newHostsDNSHolder(), hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, + currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied } // register with root zone, handler chain takes care of the routing @@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() { log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains)) + hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{ + ZeroNil: true, + IgnoreZeroValue: true, + SlicesAsSets: true, + UseStringer: true, + }) + if err != nil { + log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err) + // Fall through to apply config anyway (fail-safe approach) + } else if s.currentConfigHash == hash { + log.Debugf("not applying host config as there are no changes") + return + } + + log.Debugf("applying host config as there are changes") if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { log.Errorf("failed to apply DNS host manager update: %v", err) + return + } + + // Only update hash if it was computed successfully and config was applied + if err == nil { + s.currentConfigHash = hash } s.registerFallback(config) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index d12070128..fe1f67f66 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) { "other.example.com.", "duplicate.example.com.", }, - applyHostConfigCall: 4, + // Expect 3 calls instead of 4 because when deregistering duplicate.example.com, + // the domain remains in the config (ref count goes from 2 to 1), so the host + // config hash doesn't change and applyDNSConfig is not called. + applyHostConfigCall: 3, }, { name: "Config update with new domains after registration", @@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) { expectedMatchOnly: []string{ "extra.example.com.", }, - applyHostConfigCall: 3, + // Expect 2 calls instead of 3 because when deregistering protected.example.com, + // it's removed from extraDomains but still remains in the config (from customZones), + // so the host config hash doesn't change and applyDNSConfig is not called. + applyHostConfigCall: 2, }, { name: "Register domain that is part of nameserver group", From 73201c4f3e6973fa3c0ee7d29e921af05498b519 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 29 Dec 2025 12:47:38 +0100 Subject: [PATCH 114/118] Add conditional checks for FreeBSD diff file generation in release workflow (#5001) --- .github/workflows/release.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 345728285..2fa847dce 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -32,7 +32,18 @@ jobs: - name: Generate FreeBSD port issue body run: bash release_files/freebsd-port-issue-body.sh + - name: Check if diff was generated + id: check_diff + run: | + if ls netbird-*.diff 1> /dev/null 2>&1; then + echo "diff_exists=true" >> $GITHUB_OUTPUT + else + echo "diff_exists=false" >> $GITHUB_OUTPUT + echo "No diff file generated (port may already be up to date)" + fi + - name: Extract version + if: steps.check_diff.outputs.diff_exists == 'true' id: version run: | VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/') @@ -41,6 +52,7 @@ jobs: cat netbird-*.diff - name: Test FreeBSD port + if: steps.check_diff.outputs.diff_exists == 'true' uses: vmactions/freebsd-vm@v1 with: usesh: true @@ -92,6 +104,7 @@ jobs: echo "FreeBSD port test completed successfully!" - name: Upload FreeBSD port files + if: steps.check_diff.outputs.diff_exists == 'true' uses: actions/upload-artifact@v4 with: name: freebsd-port-files From 67f7b2404ede1f1ff5d5b4aab09546c074777733 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 29 Dec 2025 12:50:41 +0100 Subject: [PATCH 115/118] [client, management] Feature/ssh fine grained access (#4969) Add fine-grained SSH access control with authorized users/groups --- client/internal/engine.go | 2 + client/internal/engine_ssh.go | 38 + client/ssh/auth/auth.go | 184 +++ client/ssh/auth/auth_test.go | 612 ++++++++++ client/ssh/proxy/proxy_test.go | 25 +- client/ssh/server/jwt_test.go | 18 + client/ssh/server/server.go | 51 +- .../network_map/controller/controller.go | 10 +- management/internals/modules/peers/manager.go | 2 + .../internals/shared/grpc/conversion.go | 50 +- management/internals/shared/grpc/server.go | 2 +- management/server/account.go | 24 +- management/server/account_test.go | 2 +- .../http/handlers/peers/peers_handler.go | 30 +- .../handlers/policies/policies_handler.go | 18 + management/server/peer.go | 4 +- management/server/policy_test.go | 32 +- management/server/store/sql_store.go | 9 +- management/server/types/account.go | 113 +- management/server/types/account_test.go | 187 +++ management/server/types/network.go | 2 + .../server/types/networkmap_golden_test.go | 18 +- management/server/types/policy.go | 4 + management/server/types/policyrule.go | 12 + management/server/user.go | 8 +- management/server/user_test.go | 8 +- shared/management/http/api/openapi.yml | 55 +- shared/management/http/api/types.gen.go | 75 +- shared/management/proto/management.pb.go | 1006 ++++++++++------- shared/management/proto/management.proto | 18 + shared/sshauth/userhash.go | 28 + shared/sshauth/userhash_test.go | 210 ++++ 32 files changed, 2345 insertions(+), 512 deletions(-) create mode 100644 client/ssh/auth/auth.go create mode 100644 client/ssh/auth/auth_test.go create mode 100644 shared/sshauth/userhash.go create mode 100644 shared/sshauth/userhash_test.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 084bafdbd..55645b494 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1151,6 +1151,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil { log.Warnf("failed to update SSH client config: %v", err) } + + e.updateSSHServerAuth(networkMap.GetSshAuth()) } // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index 861b3d6d2..e683d8cee 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -11,15 +11,18 @@ import ( firewallManager "github.com/netbirdio/netbird/client/firewall/manager" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" sshconfig "github.com/netbirdio/netbird/client/ssh/config" sshserver "github.com/netbirdio/netbird/client/ssh/server" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" ) type sshServer interface { Start(ctx context.Context, addr netip.AddrPort) error Stop() error GetStatus() (bool, []sshserver.SessionInfo) + UpdateSSHAuth(config *sshauth.Config) } func (e *Engine) setupSSHPortRedirection() error { @@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio return sshServer.GetStatus() } + +// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server +func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) { + if sshAuth == nil { + return + } + + if e.sshServer == nil { + return + } + + protoUsers := sshAuth.GetAuthorizedUsers() + authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers)) + for i, hash := range protoUsers { + if len(hash) != 16 { + log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash)) + return + } + authorizedUsers[i] = sshuserhash.UserIDHash(hash) + } + + machineUsers := make(map[string][]uint32) + for osUser, indexes := range sshAuth.GetMachineUsers() { + machineUsers[osUser] = indexes.GetIndexes() + } + + // Update SSH server with new authorization configuration + authConfig := &sshauth.Config{ + UserIDClaim: sshAuth.GetUserIDClaim(), + AuthorizedUsers: authorizedUsers, + MachineUsers: machineUsers, + } + + e.sshServer.UpdateSSHAuth(authConfig) +} diff --git a/client/ssh/auth/auth.go b/client/ssh/auth/auth.go new file mode 100644 index 000000000..488b6e12e --- /dev/null +++ b/client/ssh/auth/auth.go @@ -0,0 +1,184 @@ +package auth + +import ( + "errors" + "fmt" + "sync" + + log "github.com/sirupsen/logrus" + + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" +) + +const ( + // DefaultUserIDClaim is the default JWT claim used to extract user IDs + DefaultUserIDClaim = "sub" + // Wildcard is a special user ID that matches all users + Wildcard = "*" +) + +var ( + ErrEmptyUserID = errors.New("JWT user ID is empty") + ErrUserNotAuthorized = errors.New("user is not authorized to access this peer") + ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user") + ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user") +) + +// Authorizer handles SSH fine-grained access control authorization +type Authorizer struct { + // UserIDClaim is the JWT claim to extract the user ID from + userIDClaim string + + // authorizedUsers is a list of hashed user IDs authorized to access this peer + authorizedUsers []sshuserhash.UserIDHash + + // machineUsers maps OS login usernames to lists of authorized user indexes + machineUsers map[string][]uint32 + + // mu protects the list of users + mu sync.RWMutex +} + +// Config contains configuration for the SSH authorizer +type Config struct { + // UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email") + UserIDClaim string + + // AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer + AuthorizedUsers []sshuserhash.UserIDHash + + // MachineUsers maps OS login usernames to indexes in AuthorizedUsers + // If a user wants to login as a specific OS user, their index must be in the corresponding list + MachineUsers map[string][]uint32 +} + +// NewAuthorizer creates a new SSH authorizer with empty configuration +func NewAuthorizer() *Authorizer { + a := &Authorizer{ + userIDClaim: DefaultUserIDClaim, + machineUsers: make(map[string][]uint32), + } + + return a +} + +// Update updates the authorizer configuration with new values +func (a *Authorizer) Update(config *Config) { + a.mu.Lock() + defer a.mu.Unlock() + + if config == nil { + // Clear authorization + a.userIDClaim = DefaultUserIDClaim + a.authorizedUsers = []sshuserhash.UserIDHash{} + a.machineUsers = make(map[string][]uint32) + log.Info("SSH authorization cleared") + return + } + + userIDClaim := config.UserIDClaim + if userIDClaim == "" { + userIDClaim = DefaultUserIDClaim + } + a.userIDClaim = userIDClaim + + // Store authorized users list + a.authorizedUsers = config.AuthorizedUsers + + // Store machine users mapping + machineUsers := make(map[string][]uint32) + for osUser, indexes := range config.MachineUsers { + if len(indexes) > 0 { + machineUsers[osUser] = indexes + } + } + a.machineUsers = machineUsers + + log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings", + len(config.AuthorizedUsers), len(machineUsers)) +} + +// Authorize validates if a user is authorized to login as the specified OS user +// Returns nil if authorized, or an error describing why authorization failed +func (a *Authorizer) Authorize(jwtUserID, osUsername string) error { + if jwtUserID == "" { + log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername) + return ErrEmptyUserID + } + + // Hash the JWT user ID for comparison + hashedUserID, err := sshuserhash.HashUserID(jwtUserID) + if err != nil { + log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err) + return fmt.Errorf("failed to hash user ID: %w", err) + } + + a.mu.RLock() + defer a.mu.RUnlock() + + // Find the index of this user in the authorized list + userIndex, found := a.findUserIndex(hashedUserID) + if !found { + log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername) + return ErrUserNotAuthorized + } + + return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex) +} + +// checkMachineUserMapping validates if a user's index is authorized for the specified OS user +// Checks wildcard mapping first, then specific OS user mappings +func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error { + // If wildcard exists and user's index is in the wildcard list, allow access to any OS user + if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard { + if a.isIndexInList(uint32(userIndex), wildcardIndexes) { + log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex) + return nil + } + } + + // Check for specific OS username mapping + allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername] + if !hasMachineUserMapping { + // No mapping for this OS user - deny by default (fail closed) + log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID) + return ErrNoMachineUserMapping + } + + // Check if user's index is in the allowed indexes for this specific OS user + if !a.isIndexInList(uint32(userIndex), allowedIndexes) { + log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex) + return ErrUserNotMappedToOSUser + } + + log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex) + return nil +} + +// GetUserIDClaim returns the JWT claim name used to extract user IDs +func (a *Authorizer) GetUserIDClaim() string { + a.mu.RLock() + defer a.mu.RUnlock() + return a.userIDClaim +} + +// findUserIndex finds the index of a hashed user ID in the authorized users list +// Returns the index and true if found, 0 and false if not found +func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) { + for i, id := range a.authorizedUsers { + if id == hashedUserID { + return i, true + } + } + return 0, false +} + +// isIndexInList checks if an index exists in a list of indexes +func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool { + for _, idx := range indexes { + if idx == index { + return true + } + } + return false +} diff --git a/client/ssh/auth/auth_test.go b/client/ssh/auth/auth_test.go new file mode 100644 index 000000000..2b3b5a414 --- /dev/null +++ b/client/ssh/auth/auth_test.go @@ -0,0 +1,612 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/sshauth" +) + +func TestAuthorizer_Authorize_UserNotInList(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list with one user + authorizedUserHash, err := sshauth.HashUserID("authorized-user") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Try to authorize a different user + err = authorizer.Authorize("unauthorized-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed) + } + authorizer.Update(config) + + // All attempts should fail when no machine user mappings exist (fail closed) + err = authorizer.Authorize("user1", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, // user1 and user2 can access root + "postgres": {1, 2}, // user2 and user3 can access postgres + "admin": {0}, // only user1 can access admin + }, + } + authorizer.Update(config) + + // user1 (index 0) should access root and admin + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user1", "admin") + assert.NoError(t, err) + + // user2 (index 1) should access root and postgres + err = authorizer.Authorize("user2", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + // user3 (index 2) should access postgres + err = authorizer.Authorize("user3", "postgres") + assert.NoError(t, err) +} + +func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, // user1 and user2 can access root + "postgres": {1, 2}, // user2 and user3 can access postgres + "admin": {0}, // only user1 can access admin + }, + } + authorizer.Update(config) + + // user1 (index 0) should NOT access postgres + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user2 (index 1) should NOT access admin + err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user3 (index 2) should NOT access root + err = authorizer.Authorize("user3", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user3 (index 2) should NOT access admin + err = authorizer.Authorize("user3", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) +} + +func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, // only root is mapped + }, + } + authorizer.Update(config) + + // user1 should NOT access an unmapped OS user (fail closed) + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Empty user ID should fail + err = authorizer.Authorize("", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrEmptyUserID) +} + +func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up multiple authorized users + userHashes := make([]sshauth.UserIDHash, 10) + for i := 0; i < 10; i++ { + hash, err := sshauth.HashUserID("user" + string(rune('0'+i))) + require.NoError(t, err) + userHashes[i] = hash + } + + // Create machine user mapping for all users + rootIndexes := make([]uint32, 10) + for i := 0; i < 10; i++ { + rootIndexes[i] = uint32(i) + } + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: userHashes, + MachineUsers: map[string][]uint32{ + "root": rootIndexes, + }, + } + authorizer.Update(config) + + // All users should be authorized for root + for i := 0; i < 10; i++ { + err := authorizer.Authorize("user"+string(rune('0'+i)), "root") + assert.NoError(t, err, "user%d should be authorized", i) + } + + // User not in list should fail + err := authorizer.Authorize("unknown-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up initial configuration + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{"root": {0}}, + } + authorizer.Update(config) + + // user1 should be authorized + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // Clear configuration + authorizer.Update(nil) + + // user1 should no longer be authorized + err = authorizer.Authorize("user1", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + // Machine users with empty index lists should be filtered out + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, + "postgres": {}, // empty list - should be filtered out + "admin": nil, // nil list - should be filtered out + }, + } + authorizer.Update(config) + + // root should work + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // postgres should fail (no mapping) + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + // admin should fail (no mapping) + err = authorizer.Authorize("user1", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_CustomUserIDClaim(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up with custom user ID claim + user1Hash, err := sshauth.HashUserID("user@example.com") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: "email", + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, + }, + } + authorizer.Update(config) + + // Verify the custom claim is set + assert.Equal(t, "email", authorizer.GetUserIDClaim()) + + // Authorize with email as user ID + err = authorizer.Authorize("user@example.com", "root") + assert.NoError(t, err) +} + +func TestAuthorizer_DefaultUserIDClaim(t *testing.T) { + authorizer := NewAuthorizer() + + // Verify default claim + assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim()) + assert.Equal(t, "sub", authorizer.GetUserIDClaim()) + + // Set up with empty user ID claim (should use default) + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: "", // empty - should use default + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Should fall back to default + assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim()) +} + +func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) { + authorizer := NewAuthorizer() + + // Create a large authorized users list + const numUsers = 1000 + userHashes := make([]sshauth.UserIDHash, numUsers) + for i := 0; i < numUsers; i++ { + hash, err := sshauth.HashUserID("user" + string(rune(i))) + require.NoError(t, err) + userHashes[i] = hash + } + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: userHashes, + MachineUsers: map[string][]uint32{ + "root": {0, 500, 999}, // first, middle, and last user + }, + } + authorizer.Update(config) + + // First user should have access + err := authorizer.Authorize("user"+string(rune(0)), "root") + assert.NoError(t, err) + + // Middle user should have access + err = authorizer.Authorize("user"+string(rune(500)), "root") + assert.NoError(t, err) + + // Last user should have access + err = authorizer.Authorize("user"+string(rune(999)), "root") + assert.NoError(t, err) + + // User not in mapping should NOT have access + err = authorizer.Authorize("user"+string(rune(100)), "root") + assert.Error(t, err) +} + +func TestAuthorizer_ConcurrentAuthorization(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, + }, + } + authorizer.Update(config) + + // Test concurrent authorization calls (should be safe to read concurrently) + const numGoroutines = 100 + errChan := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + user := "user1" + if idx%2 == 0 { + user = "user2" + } + err := authorizer.Authorize(user, "root") + errChan <- err + }(i) + } + + // Wait for all goroutines to complete and collect errors + for i := 0; i < numGoroutines; i++ { + err := <-errChan + assert.NoError(t, err) + } +} + +func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + // Configure with wildcard - all authorized users can access any OS user + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "*": {0, 1, 2}, // wildcard with all user indexes + }, + } + authorizer.Update(config) + + // All authorized users should be able to access any OS user + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + err = authorizer.Authorize("user3", "admin") + assert.NoError(t, err) + + err = authorizer.Authorize("user1", "ubuntu") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "nginx") + assert.NoError(t, err) + + err = authorizer.Authorize("user3", "docker") + assert.NoError(t, err) +} + +func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + // Configure with wildcard + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "*": {0}, + }, + } + authorizer.Update(config) + + // user1 should have access + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // Unauthorized user should still be denied even with wildcard + err = authorizer.Authorize("unauthorized-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure with both wildcard and specific mappings + // Wildcard takes precedence for users in the wildcard index list + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "*": {0, 1}, // wildcard for both users + "root": {0}, // specific mapping that would normally restrict to user1 only + }, + } + authorizer.Update(config) + + // Both users should be able to access root via wildcard (takes precedence over specific mapping) + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "root") + assert.NoError(t, err) + + // Both users should be able to access any other OS user via wildcard + err = authorizer.Authorize("user1", "postgres") + assert.NoError(t, err) + + err = authorizer.Authorize("user2", "admin") + assert.NoError(t, err) +} + +func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure WITHOUT wildcard - only specific mappings + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, // only user1 + "postgres": {1}, // only user2 + }, + } + authorizer.Update(config) + + // user1 can access root + err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // user2 can access postgres + err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + // user1 cannot access postgres + err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user2 cannot access root + err = authorizer.Authorize("user2", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // Neither can access unmapped OS users + err = authorizer.Authorize("user1", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) { + // This test covers the scenario where wildcard exists with limited indexes. + // Only users whose indexes are in the wildcard list can access any OS user via wildcard. + // Other users can only access OS users they are explicitly mapped to. + authorizer := NewAuthorizer() + + // Create two authorized user hashes (simulating the base64-encoded hashes in the config) + wasmHash, err := sshauth.HashUserID("wasm") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure with wildcard having only index 0, and specific mappings for other OS users + config := &Config{ + UserIDClaim: "sub", + AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash}, + MachineUsers: map[string][]uint32{ + "*": {0}, // wildcard with only index 0 - only wasm has wildcard access + "alice": {1}, // specific mapping for user2 + "bob": {1}, // specific mapping for user2 + }, + } + authorizer.Update(config) + + // wasm (index 0) should access any OS user via wildcard + err = authorizer.Authorize("wasm", "root") + assert.NoError(t, err, "wasm should access root via wildcard") + + err = authorizer.Authorize("wasm", "alice") + assert.NoError(t, err, "wasm should access alice via wildcard") + + err = authorizer.Authorize("wasm", "bob") + assert.NoError(t, err, "wasm should access bob via wildcard") + + err = authorizer.Authorize("wasm", "postgres") + assert.NoError(t, err, "wasm should access postgres via wildcard") + + // user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres + err = authorizer.Authorize("user2", "alice") + assert.NoError(t, err, "user2 should access alice via explicit mapping") + + err = authorizer.Authorize("user2", "bob") + assert.NoError(t, err, "user2 should access bob via explicit mapping") + + err = authorizer.Authorize("user2", "root") + assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)") + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + err = authorizer.Authorize("user2", "postgres") + assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)") + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + // Unauthorized user should still be denied + err = authorizer.Authorize("user3", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied") +} diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index 582f9c07b..81d588801 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -27,9 +27,11 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/server" "github.com/netbirdio/netbird/client/ssh/testutil" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" ) func TestMain(m *testing.M) { @@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) { sshServer := server.New(serverConfig) sshServer.SetAllowRootLogin(true) + // Configure SSH authorization for the test user + testUsername := testutil.GetTestUsername(t) + testJWTUser := "test-username" + testUserHash, err := sshuserhash.HashUserID(testJWTUser) + require.NoError(t, err) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + testUsername: {0}, // Index 0 in AuthorizedUsers + }, + } + sshServer.UpdateSSHAuth(authConfig) + sshServerAddr := server.StartTestServer(t, sshServer) defer func() { _ = sshServer.Stop() }() @@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) { mockDaemon.setHostKey(host, hostPubKey) - validToken := generateValidJWT(t, privateKey, issuer, audience) + validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser) mockDaemon.setJWTToken(validToken) - proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil) + proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil) require.NoError(t, err) clientConn, proxyConn := net.Pipe() @@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { return privateKey, jwksJSON } -func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string { t.Helper() claims := jwt.MapClaims{ "iss": issuer, "aud": audience, - "sub": "test-user", + "sub": user, "exp": time.Now().Add(time.Hour).Unix(), "iat": time.Now().Unix(), } diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index 1f3bac76d..d36d7cbbf 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -23,10 +23,12 @@ import ( "github.com/stretchr/testify/require" nbssh "github.com/netbirdio/netbird/client/ssh" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/client" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/testutil" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" ) func TestJWTEnforcement(t *testing.T) { @@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) { tc.setupServer(server) } + // Always set up authorization for test-user to ensure tests fail at JWT validation stage + testUserHash, err := sshuserhash.HashUserID("test-user") + require.NoError(t, err) + + // Get current OS username for machine user mapping + currentUser := testutil.GetTestUsername(t) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + currentUser: {0}, // Allow test-user (index 0) to access current OS user + }, + } + server.UpdateSSHAuth(authConfig) + serverAddr := StartTestServer(t, server) defer require.NoError(t, server.Stop()) diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 37763ee0e..82718d002 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -21,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth/jwt" @@ -138,6 +139,8 @@ type Server struct { jwtExtractor *jwt.ClaimsExtractor jwtConfig *JWTConfig + authorizer *sshauth.Authorizer + suSupportsPty bool loginIsUtilLinux bool } @@ -179,6 +182,7 @@ func New(config *Config) *Server { sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), jwtEnabled: config.JWT != nil, jwtConfig: config.JWT, + authorizer: sshauth.NewAuthorizer(), // Initialize with empty config } return s @@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) { s.wgAddress = addr } +// UpdateSSHAuth updates the SSH fine-grained access control configuration +// This should be called when network map updates include new SSH auth configuration +func (s *Server) UpdateSSHAuth(config *sshauth.Config) { + s.mu.Lock() + defer s.mu.Unlock() + + // Reset JWT validator/extractor to pick up new userIDClaim + s.jwtValidator = nil + s.jwtExtractor = nil + + s.authorizer.Update(config) +} + // ensureJWTValidator initializes the JWT validator and extractor if not already initialized func (s *Server) ensureJWTValidator() error { s.mu.RLock() @@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error { return nil } config := s.jwtConfig + authorizer := s.authorizer s.mu.RUnlock() if config == nil { @@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error { true, ) - extractor := jwt.NewClaimsExtractor( + // Use custom userIDClaim from authorizer if available + extractorOptions := []jwt.ClaimsExtractorOption{ jwt.WithAudience(config.Audience), - ) + } + if authorizer.GetUserIDClaim() != "" { + extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim())) + log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim()) + } + + extractor := jwt.NewClaimsExtractor(extractorOptions...) s.mu.Lock() defer s.mu.Unlock() @@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int } func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { + osUsername := ctx.User() + remoteAddr := ctx.RemoteAddr() + if err := s.ensureJWTValidator(); err != nil { - log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err) return false } token, err := s.validateJWTToken(password) if err != nil { - log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err) return false } userAuth, err := s.extractAndValidateUser(token) if err != nil { - log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err) return false } - key := newAuthKey(ctx.User(), ctx.RemoteAddr()) + s.mu.RLock() + authorizer := s.authorizer + s.mu.RUnlock() + + if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil { + log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err) + return false + } + + key := newAuthKey(osUsername, remoteAddr) s.mu.Lock() s.pendingAuthJWT[key] = userAuth.UserId s.mu.Unlock() - log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr()) + log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr) return true } diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 022ea774c..df16e1922 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() if c.experimentalNetworkMap(accountID) { c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) @@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin if c.experimentalNetworkMap(accountID) { remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) } c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) @@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() postureChecks, err := c.getPeerPostureChecks(account, peerId) if err != nil { @@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe if c.experimentalNetworkMap(accountId) { remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr if c.experimentalNetworkMap(accountID) { networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers()) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N if c.experimentalNetworkMap(peer.AccountID) { networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index e82f19e63..b200b9663 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs } } + m.accountManager.UpdateAccountPeers(ctx, accountID) + return nil } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 1a0dc009b..f984c73df 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -6,7 +6,10 @@ import ( "net/url" "strings" + log "github.com/sirupsen/logrus" + integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" @@ -16,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/sshauth" ) func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { @@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken return nbConfig } -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) sshConfig := &proto.SSHConfig{ - SshEnabled: peer.SSHEnabled, + SshEnabled: peer.SSHEnabled || enableSSH, } - if peer.SSHEnabled { + if sshConfig.SshEnabled { sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) } @@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), }, Checks: toProtocolChecks(ctx, checks), } @@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb response.NetworkMap.ForwardingRules = forwardingRules } + if networkMap.AuthorizedUsers != nil { + hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers) + userIDClaim := auth.DefaultUserIDClaim + if httpConfig != nil && httpConfig.AuthUserIDClaim != "" { + userIDClaim = httpConfig.AuthUserIDClaim + } + response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim} + } + return response } +func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) { + userIDToIndex := make(map[string]uint32) + var hashedUsers [][]byte + machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers)) + + for machineUser, users := range authorizedUsers { + indexes := make([]uint32, 0, len(users)) + for userID := range users { + idx, exists := userIDToIndex[userID] + if !exists { + hash, err := sshauth.HashUserID(userID) + if err != nil { + log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err) + continue + } + idx = uint32(len(hashedUsers)) + userIDToIndex[userID] = idx + hashedUsers = append(hashedUsers, hash[:]) + } + indexes = append(indexes, idx) + } + machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes} + } + + return hashedUsers, machineUsers +} + func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { for _, rPeer := range peers { dst = append(dst, &proto.RemotePeerConfig{ diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 462e2e6eb..ad6b34c5f 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -635,7 +635,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH), Checks: toProtocolChecks(ctx, postureChecks), } diff --git a/management/server/account.go b/management/server/account.go index 04f83842e..405a3c0f6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } } - if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) - if err != nil { - return err - } + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) + if err != nil { + return err + } - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) - if err != nil { - return err - } + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) + if err != nil { + return err + } - if removedGroupAffectsPeers || newGroupsAffectsPeers { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) - am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) - } + if removedGroupAffectsPeers || newGroupsAffectsPeers { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) + am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } return nil diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f125e3a0..25818ada2 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index f531f0cdb..a5c9ab0ac 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } @@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) PortRanges: []types.RulePortRange{portRange}, }}, } + if protocol == types.PolicyRuleProtocolNetbirdSSH { + policy.Rules[0].AuthorizedUser = userAuth.UserId + } _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) if err != nil { @@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, + LocalFlags: &api.PeerLocalFlags{ + BlockInbound: &peer.Meta.Flags.BlockInbound, + BlockLanAccess: &peer.Meta.Flags.BlockLANAccess, + DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes, + DisableDns: &peer.Meta.Flags.DisableDNS, + DisableFirewall: &peer.Meta.Flags.DisableFirewall, + DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes, + LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled, + RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled, + RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive, + ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed, + }, } if !approved { @@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn if osVersion == "" { osVersion = peer.Meta.Core } - return &api.PeerBatch{ CreatedAt: peer.CreatedAt, Id: peer.ID, @@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, + LocalFlags: &api.PeerLocalFlags{ + BlockInbound: &peer.Meta.Flags.BlockInbound, + BlockLanAccess: &peer.Meta.Flags.BlockLANAccess, + DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes, + DisableDns: &peer.Meta.Flags.DisableDNS, + DisableFirewall: &peer.Meta.Flags.DisableFirewall, + DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes, + LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled, + RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled, + RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive, + ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed, + }, } } diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index ab1639ab1..e4d1d73df 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s pr.Protocol = types.PolicyRuleProtocolUDP case api.PolicyRuleUpdateProtocolIcmp: pr.Protocol = types.PolicyRuleProtocolICMP + case api.PolicyRuleUpdateProtocolNetbirdSsh: + pr.Protocol = types.PolicyRuleProtocolNetbirdSSH default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return @@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s } } + if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 { + for _, sourceGroupID := range pr.Sources { + _, ok := (*rule.AuthorizedGroups)[sourceGroupID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w) + return + } + } + pr.AuthorizedGroups = *rule.AuthorizedGroups + } + // validate policy object if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP { if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { @@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { DestinationResource: r.DestinationResource.ToAPIResponse(), } + if len(r.AuthorizedGroups) != 0 { + authorizedGroupsCopy := r.AuthorizedGroups + rule.AuthorizedGroups = &authorizedGroupsCopy + } + if len(r.Ports) != 0 { portsCopy := r.Ports rule.Ports = &portsCopy diff --git a/management/server/peer.go b/management/server/peer.go index 49f5bf2a5..7c48a8052 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap) + aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers()) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun } for _, p := range userPeers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap) + aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers()) for _, aclPeer := range aclPeers { if aclPeer.ID == peer.ID { return peer, nil diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 90fe8f036..a3f987732 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers()) assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 8) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }) t.Run("check port ranges support for older peers", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 1) assert.Contains(t, peers, account.Peers["peerI"]) @@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerC"]) expectedFirewallRules := []*types.FirewallRule{ @@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerB"]) expectedFirewallRules := []*types.FirewallRule{ @@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerC"]) expectedFirewallRules := []*types.FirewallRule{ @@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerB"]) expectedFirewallRules := []*types.FirewallRule{ @@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 7) expectedFirewallRules := []*types.FirewallRule{ @@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 2b8981b97..da05ba416 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -1910,16 +1910,16 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t if len(policyIDs) == 0 { return nil, nil } - const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)` rows, err := s.pool.Query(ctx, query, policyIDs) if err != nil { return nil, err } rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { var r types.PolicyRule - var dest, destRes, sources, sourceRes, ports, portRanges []byte + var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &r.AuthorizedUser) if err == nil { if enabled.Valid { r.Enabled = enabled.Bool @@ -1945,6 +1945,9 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t if portRanges != nil { _ = json.Unmarshal(portRanges, &r.PortRanges) } + if authorizedGroups != nil { + _ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups) + } } return &r, err }) diff --git a/management/server/types/account.go b/management/server/types/account.go index 9e86d8936..c43e0bb57 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -16,6 +16,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -45,8 +46,10 @@ const ( // nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections. nativeSSHPortString = "22022" + nativeSSHPortNumber = 22022 // defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections. defaultSSHPortString = "22" + defaultSSHPortNumber = 22 ) type supportedFeatures struct { @@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap( resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, metrics *telemetry.AccountManagerMetrics, + groupIDToUserIDs map[string][]string, ) *NetworkMap { start := time.Now() peer := a.Peers[peerID] @@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap( } } - aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap) + aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap( OfflinePeers: expiredPeers, FirewallRules: firewallRules, RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), + AuthorizedUsers: authorizedUsers, + EnableSSH: enableSSH, } if metrics != nil { @@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map // GetPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer) + authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs + sshEnabled := false for _, policy := range a.Policies { if !policy.Enabled { @@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P if peerInDestinations { generateResources(rule, sourcePeers, FirewallRuleDirectionIN) } + + if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := groupIDToUserIDs[groupID] + if !ok { + log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID) + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() + } + } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() + } } } - return getAccumulatedResources() + peers, fwRules := getAccumulatedResources() + return peers, fwRules, authorizedUsers, sshEnabled +} + +func (a *Account) getAllowedUserIDs() map[string]struct{} { + users := make(map[string]struct{}) + for _, nbUser := range a.Users { + if !nbUser.IsBlocked() && !nbUser.IsServiceUser { + users[nbUser.Id] = struct{}{} + } + } + return users } // connResourcesGenerator returns generator and accumulator function which returns the result of generator calls @@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer peersExists[peer.ID] = struct{}{} } + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + fr := FirewallRule{ PolicyID: rule.ID, PeerIP: peer.IP.String(), Direction: direction, Action: string(rule.Action), - Protocol: string(rule.Protocol), + Protocol: string(protocol), } ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + @@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer } } +func policyRuleImpliesLegacySSH(rule *PolicyRule) bool { + return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges))) +} + +func portRangeIncludesSSH(portRanges []RulePortRange) bool { + for _, pr := range portRanges { + if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) { + return true + } + } + return false +} + +func portsIncludesSSH(ports []string) bool { + for _, port := range ports { + if port == defaultSSHPortString || port == nativeSSHPortString { + return true + } + } + return false +} + // getAllPeersFromGroups for given peer ID and list of groups // // Returns a list of peers from specified groups that pass specified posture checks @@ -1660,6 +1743,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error { return nil } +func (a *Account) GetActiveGroupUsers() map[string][]string { + allGroupID := "" + group, err := a.GetGroupAll() + if err != nil { + log.Errorf("failed to get group all: %v", err) + } else { + allGroupID = group.ID + } + groups := make(map[string][]string, len(a.GroupsG)) + for _, user := range a.Users { + if !user.IsBlocked() && !user.IsServiceUser { + for _, groupID := range user.AutoGroups { + groups[groupID] = append(groups[groupID], user.Id) + } + groups[allGroupID] = append(groups[allGroupID], user.Id) + } + } + return groups +} + // expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule { features := peerSupportedFirewallFeatures(peer.Meta.WtVersion) @@ -1691,7 +1794,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer expanded = append(expanded, &fr) } - if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) { + if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH { expanded = addNativeSSHRule(base, expanded) } diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f9aa6a1c2..2c9f2428d 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { } } +func Test_GetActiveGroupUsers(t *testing.T) { + tests := []struct { + name string + account *Account + expected map[string][]string + }{ + { + name: "all users are active", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2", "group3"}, + Blocked: false, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user3"}, + "group2": {"user1", "user2"}, + "group3": {"user2"}, + "": {"user1", "user2", "user3"}, + }, + }, + { + name: "some users are blocked", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2", "group3"}, + Blocked: true, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1", "group3"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user3"}, + "group2": {"user1"}, + "group3": {"user3"}, + "": {"user1", "user3"}, + }, + }, + { + name: "all users are blocked", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1"}, + Blocked: true, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2"}, + Blocked: true, + }, + }, + }, + expected: map[string][]string{}, + }, + { + name: "user with no auto groups", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user2"}, + "": {"user1", "user2"}, + }, + }, + { + name: "empty account", + account: &Account{ + Users: map[string]*User{}, + }, + expected: map[string][]string{}, + }, + { + name: "multiple users in same group", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user2", "user3"}, + "": {"user1", "user2", "user3"}, + }, + }, + { + name: "user in multiple groups with blocked users", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2", "group3"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1", "group2"}, + Blocked: true, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group3"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1"}, + "group2": {"user1"}, + "group3": {"user1", "user3"}, + "": {"user1", "user3"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetActiveGroupUsers() + + // Check that the number of groups matches + assert.Equal(t, len(tt.expected), len(result), "number of groups should match") + + // Check each group's users + for groupID, expectedUsers := range tt.expected { + actualUsers, exists := result[groupID] + assert.True(t, exists, "group %s should exist in result", groupID) + assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID) + } + + // Ensure no extra groups in result + for groupID := range result { + _, exists := tt.expected[groupID] + assert.True(t, exists, "unexpected group %s in result", groupID) + } + }) + } +} + func Test_FilterZoneRecordsForPeers(t *testing.T) { tests := []struct { name string diff --git a/management/server/types/network.go b/management/server/types/network.go index ffc019565..d3708d80a 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -38,6 +38,8 @@ type NetworkMap struct { FirewallRules []*FirewallRule RoutesFirewallRules []*RouteFirewallRule ForwardingRules []*ForwardingRule + AuthorizedUsers map[string]map[string]struct{} + EnableSSH bool } func (nm *NetworkMap) Merge(other *NetworkMap) { diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go index d85aaabb2..913094e4c 100644 --- a/management/server/types/networkmap_golden_test.go +++ b/management/server/types/networkmap_golden_test.go @@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) normalizeAndSortNetworkMap(networkMap) @@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) { b.Run("old builder", func(b *testing.B) { for range b.N { for _, peerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) @@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) normalizeAndSortNetworkMap(networkMap) @@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { b.Run("old builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) @@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) normalizeAndSortNetworkMap(networkMap) @@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { b.Run("old builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) @@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) normalizeAndSortNetworkMap(networkMap) @@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) normalizeAndSortNetworkMap(networkMap) @@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { b.Run("old builder after delete", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 5e86a87c6..d4e1a8816 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -23,6 +23,8 @@ const ( PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") // PolicyRuleProtocolICMP type of traffic PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") + // PolicyRuleProtocolNetbirdSSH type of traffic + PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh") ) const ( @@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) protocol = PolicyRuleProtocolUDP case "icmp": return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") + case "netbird-ssh": + return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil default: return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) } diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index 2643ae45c..bb75dd555 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -80,6 +80,12 @@ type PolicyRule struct { // PortRanges a list of port ranges. PortRanges []RulePortRange `gorm:"serializer:json"` + + // AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh + AuthorizedGroups map[string][]string `gorm:"serializer:json"` + + // AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh + AuthorizedUser string } // Copy returns a copy of a policy rule @@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule { Protocol: pm.Protocol, Ports: make([]string, len(pm.Ports)), PortRanges: make([]RulePortRange, len(pm.PortRanges)), + AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)), + AuthorizedUser: pm.AuthorizedUser, } copy(rule.Destinations, pm.Destinations) copy(rule.Sources, pm.Sources) copy(rule.Ports, pm.Ports) copy(rule.PortRanges, pm.PortRanges) + for k, v := range pm.AuthorizedGroups { + rule.AuthorizedGroups[k] = make([]string, len(v)) + copy(rule.AuthorizedGroups[k], v) + } return rule } diff --git a/management/server/user.go b/management/server/user.go index ca02f91e6..9d4620462 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( + _, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } - if userHadPeers { - updateAccountPeers = true - } + updateAccountPeers = true err = transaction.SaveUser(ctx, updatedUser) if err != nil { @@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } } - if settings.GroupsPropagationEnabled && updateAccountPeers { + if updateAccountPeers { if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } diff --git a/management/server/user_test.go b/management/server/user_test.go index 0d778cfa2..3032ee3e8 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { updateManager.CloseChannel(context.Background(), peer1.ID) }) - // Creating a new regular user should not update account peers and not send peer update + // Creating a new regular user should send peer update (as users are not filtered yet) t.Run("creating new regular user with no groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) - // updating user with no linked peers should not update account peers and not send peer update + // updating user with no linked peers should update account peers and send peer update (as users are not filtered yet) t.Run("updating user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 42f2fd89e..c9edcdda6 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -488,6 +488,8 @@ components: description: Indicates whether the peer is ephemeral or not type: boolean example: false + local_flags: + $ref: '#/components/schemas/PeerLocalFlags' required: - city_name - connected @@ -514,6 +516,49 @@ components: - serial_number - extra_dns_labels - ephemeral + PeerLocalFlags: + type: object + properties: + rosenpass_enabled: + description: Indicates whether Rosenpass is enabled on this peer + type: boolean + example: true + rosenpass_permissive: + description: Indicates whether Rosenpass is in permissive mode or not + type: boolean + example: false + server_ssh_allowed: + description: Indicates whether SSH access this peer is allowed or not + type: boolean + example: true + disable_client_routes: + description: Indicates whether client routes are disabled on this peer or not + type: boolean + example: false + disable_server_routes: + description: Indicates whether server routes are disabled on this peer or not + type: boolean + example: false + disable_dns: + description: Indicates whether DNS management is disabled on this peer or not + type: boolean + example: false + disable_firewall: + description: Indicates whether firewall management is disabled on this peer or not + type: boolean + example: false + block_lan_access: + description: Indicates whether LAN access is blocked on this peer when used as a routing peer + type: boolean + example: false + block_inbound: + description: Indicates whether inbound traffic is blocked on this peer + type: boolean + example: false + lazy_connection_enabled: + description: Indicates whether lazy connection is enabled on this peer + type: boolean + example: false PeerTemporaryAccessRequest: type: object properties: @@ -936,7 +981,7 @@ components: protocol: description: Policy rule type of the traffic type: string - enum: ["all", "tcp", "udp", "icmp"] + enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"] example: "tcp" ports: description: Policy rule affected ports @@ -949,6 +994,14 @@ components: type: array items: $ref: '#/components/schemas/RulePortRange' + authorized_groups: + description: Map of user group ids to a list of local users + type: object + additionalProperties: + type: array + items: + type: string + example: "group1" required: - name - enabled diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 74d9ca92b..f242f5a18 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -130,10 +130,11 @@ const ( // Defines values for PolicyRuleProtocol. const ( - PolicyRuleProtocolAll PolicyRuleProtocol = "all" - PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp" - PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp" - PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" + PolicyRuleProtocolAll PolicyRuleProtocol = "all" + PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp" + PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh" + PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp" + PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" ) // Defines values for PolicyRuleMinimumAction. @@ -144,10 +145,11 @@ const ( // Defines values for PolicyRuleMinimumProtocol. const ( - PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" - PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp" - PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp" - PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" + PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" + PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp" + PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh" + PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp" + PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" ) // Defines values for PolicyRuleUpdateAction. @@ -158,10 +160,11 @@ const ( // Defines values for PolicyRuleUpdateProtocol. const ( - PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" - PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp" - PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp" - PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" + PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" + PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp" + PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh" + PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp" + PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) // Defines values for ResourceType. @@ -1077,7 +1080,8 @@ type Peer struct { LastLogin time.Time `json:"last_login"` // LastSeen Last time peer connected to Netbird's management service - LastSeen time.Time `json:"last_seen"` + LastSeen time.Time `json:"last_seen"` + LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"` // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not LoginExpirationEnabled bool `json:"login_expiration_enabled"` @@ -1167,7 +1171,8 @@ type PeerBatch struct { LastLogin time.Time `json:"last_login"` // LastSeen Last time peer connected to Netbird's management service - LastSeen time.Time `json:"last_seen"` + LastSeen time.Time `json:"last_seen"` + LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"` // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not LoginExpirationEnabled bool `json:"login_expiration_enabled"` @@ -1197,6 +1202,39 @@ type PeerBatch struct { Version string `json:"version"` } +// PeerLocalFlags defines model for PeerLocalFlags. +type PeerLocalFlags struct { + // BlockInbound Indicates whether inbound traffic is blocked on this peer + BlockInbound *bool `json:"block_inbound,omitempty"` + + // BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer + BlockLanAccess *bool `json:"block_lan_access,omitempty"` + + // DisableClientRoutes Indicates whether client routes are disabled on this peer or not + DisableClientRoutes *bool `json:"disable_client_routes,omitempty"` + + // DisableDns Indicates whether DNS management is disabled on this peer or not + DisableDns *bool `json:"disable_dns,omitempty"` + + // DisableFirewall Indicates whether firewall management is disabled on this peer or not + DisableFirewall *bool `json:"disable_firewall,omitempty"` + + // DisableServerRoutes Indicates whether server routes are disabled on this peer or not + DisableServerRoutes *bool `json:"disable_server_routes,omitempty"` + + // LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer + LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"` + + // RosenpassEnabled Indicates whether Rosenpass is enabled on this peer + RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"` + + // RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not + RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"` + + // ServerSshAllowed Indicates whether SSH access this peer is allowed or not + ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"` +} + // PeerMinimum defines model for PeerMinimum. type PeerMinimum struct { // Id Peer ID @@ -1349,6 +1387,9 @@ type PolicyRule struct { // Action Policy rule accept or drops packets Action PolicyRuleAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct { // Action Policy rule accept or drops packets Action PolicyRuleMinimumAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct { // Action Policy rule accept or drops packets Action PolicyRuleUpdateAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 217b33db9..2047c51ea 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v6.33.1 // source: management.proto package proto @@ -267,7 +267,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25, 0} + return file_management_proto_rawDescGZIP(), []int{27, 0} } type EncryptedMessage struct { @@ -1942,6 +1942,8 @@ type NetworkMap struct { // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"` ForwardingRules []*ForwardingRule `protobuf:"bytes,12,rep,name=forwardingRules,proto3" json:"forwardingRules,omitempty"` + // SSHAuth represents SSH authorization configuration + SshAuth *SSHAuth `protobuf:"bytes,13,opt,name=sshAuth,proto3" json:"sshAuth,omitempty"` } func (x *NetworkMap) Reset() { @@ -2060,6 +2062,126 @@ func (x *NetworkMap) GetForwardingRules() []*ForwardingRule { return nil } +func (x *NetworkMap) GetSshAuth() *SSHAuth { + if x != nil { + return x.SshAuth + } + return nil +} + +type SSHAuth struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // UserIDClaim is the JWT claim to be used to get the users ID + UserIDClaim string `protobuf:"bytes,1,opt,name=UserIDClaim,proto3" json:"UserIDClaim,omitempty"` + // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH + AuthorizedUsers [][]byte `protobuf:"bytes,2,rep,name=AuthorizedUsers,proto3" json:"AuthorizedUsers,omitempty"` + // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list + MachineUsers map[string]*MachineUserIndexes `protobuf:"bytes,3,rep,name=machine_users,json=machineUsers,proto3" json:"machine_users,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (x *SSHAuth) Reset() { + *x = SSHAuth{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SSHAuth) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHAuth) ProtoMessage() {} + +func (x *SSHAuth) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[22] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHAuth.ProtoReflect.Descriptor instead. +func (*SSHAuth) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{22} +} + +func (x *SSHAuth) GetUserIDClaim() string { + if x != nil { + return x.UserIDClaim + } + return "" +} + +func (x *SSHAuth) GetAuthorizedUsers() [][]byte { + if x != nil { + return x.AuthorizedUsers + } + return nil +} + +func (x *SSHAuth) GetMachineUsers() map[string]*MachineUserIndexes { + if x != nil { + return x.MachineUsers + } + return nil +} + +type MachineUserIndexes struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Indexes []uint32 `protobuf:"varint,1,rep,packed,name=indexes,proto3" json:"indexes,omitempty"` +} + +func (x *MachineUserIndexes) Reset() { + *x = MachineUserIndexes{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MachineUserIndexes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MachineUserIndexes) ProtoMessage() {} + +func (x *MachineUserIndexes) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[23] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MachineUserIndexes.ProtoReflect.Descriptor instead. +func (*MachineUserIndexes) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{23} +} + +func (x *MachineUserIndexes) GetIndexes() []uint32 { + if x != nil { + return x.Indexes + } + return nil +} + // RemotePeerConfig represents a configuration of a remote peer. // The properties are used to configure WireGuard Peers sections type RemotePeerConfig struct { @@ -2081,7 +2203,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2094,7 +2216,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2107,7 +2229,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{22} + return file_management_proto_rawDescGZIP(), []int{24} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -2162,7 +2284,7 @@ type SSHConfig struct { func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2175,7 +2297,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2188,7 +2310,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{25} } func (x *SSHConfig) GetSshEnabled() bool { @@ -2222,7 +2344,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2235,7 +2357,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2248,7 +2370,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{26} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -2267,7 +2389,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2280,7 +2402,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2293,7 +2415,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -2320,7 +2442,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2333,7 +2455,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2346,7 +2468,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{28} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -2363,7 +2485,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2376,7 +2498,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2389,7 +2511,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -2435,7 +2557,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2448,7 +2570,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2461,7 +2583,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *ProviderConfig) GetClientID() string { @@ -2569,7 +2691,7 @@ type Route struct { func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2582,7 +2704,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2595,7 +2717,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *Route) GetID() string { @@ -2684,7 +2806,7 @@ type DNSConfig struct { func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2697,7 +2819,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2710,7 +2832,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *DNSConfig) GetServiceEnable() bool { @@ -2757,7 +2879,7 @@ type CustomZone struct { func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2770,7 +2892,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[33] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2783,7 +2905,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31} + return file_management_proto_rawDescGZIP(), []int{33} } func (x *CustomZone) GetDomain() string { @@ -2830,7 +2952,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2843,7 +2965,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[34] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2856,7 +2978,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{32} + return file_management_proto_rawDescGZIP(), []int{34} } func (x *SimpleRecord) GetName() string { @@ -2909,7 +3031,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2922,7 +3044,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[35] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2935,7 +3057,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33} + return file_management_proto_rawDescGZIP(), []int{35} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -2980,7 +3102,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2993,7 +3115,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[36] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3006,7 +3128,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{34} + return file_management_proto_rawDescGZIP(), []int{36} } func (x *NameServer) GetIP() string { @@ -3049,7 +3171,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3062,7 +3184,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[37] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3075,7 +3197,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{35} + return file_management_proto_rawDescGZIP(), []int{37} } func (x *FirewallRule) GetPeerIP() string { @@ -3139,7 +3261,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3152,7 +3274,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3165,7 +3287,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36} + return file_management_proto_rawDescGZIP(), []int{38} } func (x *NetworkAddress) GetNetIP() string { @@ -3193,7 +3315,7 @@ type Checks struct { func (x *Checks) Reset() { *x = Checks{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3206,7 +3328,7 @@ func (x *Checks) String() string { func (*Checks) ProtoMessage() {} func (x *Checks) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3219,7 +3341,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message { // Deprecated: Use Checks.ProtoReflect.Descriptor instead. func (*Checks) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37} + return file_management_proto_rawDescGZIP(), []int{39} } func (x *Checks) GetFiles() []string { @@ -3244,7 +3366,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3257,7 +3379,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[40] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3270,7 +3392,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38} + return file_management_proto_rawDescGZIP(), []int{40} } func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -3341,7 +3463,7 @@ type RouteFirewallRule struct { func (x *RouteFirewallRule) Reset() { *x = RouteFirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3354,7 +3476,7 @@ func (x *RouteFirewallRule) String() string { func (*RouteFirewallRule) ProtoMessage() {} func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[41] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3367,7 +3489,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. func (*RouteFirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{39} + return file_management_proto_rawDescGZIP(), []int{41} } func (x *RouteFirewallRule) GetSourceRanges() []string { @@ -3458,7 +3580,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3471,7 +3593,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[42] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3484,7 +3606,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{40} + return file_management_proto_rawDescGZIP(), []int{42} } func (x *ForwardingRule) GetProtocol() RuleProtocol { @@ -3527,7 +3649,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[41] + mi := &file_management_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3540,7 +3662,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[41] + mi := &file_management_proto_msgTypes[44] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3553,7 +3675,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38, 0} + return file_management_proto_rawDescGZIP(), []int{40, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -3839,7 +3961,7 @@ var file_management_proto_rawDesc = []byte{ 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, - 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, + 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, @@ -3883,264 +4005,286 @@ var file_management_proto_rawDesc = []byte{ 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, - 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, - 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, - 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, - 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, - 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, - 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, - 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, - 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, - 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, - 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, - 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, - 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, - 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, - 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, - 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, - 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, - 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, - 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, - 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, - 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, - 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, - 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, - 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, - 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, - 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, - 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, - 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, - 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, - 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, - 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, - 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, - 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, - 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, - 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, - 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, - 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, - 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, - 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, - 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, - 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, - 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, - 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x22, 0x74, - 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, - 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, - 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, - 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, - 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, - 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, + 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, + 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73, + 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, 0x82, 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, + 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, + 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, + 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a, + 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, + 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61, + 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, + 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, + 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, + 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, + 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, + 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, + 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, + 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, + 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, + 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, + 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, + 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, + 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, + 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, + 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, + 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, + 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, + 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, + 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, + 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, + 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, + 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, + 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, + 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, + 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, - 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, - 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, - 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, - 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, - 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, - 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, - 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, - 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, - 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, - 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, - 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, - 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, - 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, - 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, - 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, - 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, + 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, + 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, + 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, + 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, + 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, + 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, + 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, + 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, + 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, + 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, + 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, + 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, + 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, + 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, + 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, + 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, + 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, - 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, + 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -4156,7 +4300,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 42) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 45) var file_management_proto_goTypes = []interface{}{ (RuleProtocol)(0), // 0: management.RuleProtocol (RuleDirection)(0), // 1: management.RuleDirection @@ -4185,106 +4329,112 @@ var file_management_proto_goTypes = []interface{}{ (*PeerConfig)(nil), // 24: management.PeerConfig (*AutoUpdateSettings)(nil), // 25: management.AutoUpdateSettings (*NetworkMap)(nil), // 26: management.NetworkMap - (*RemotePeerConfig)(nil), // 27: management.RemotePeerConfig - (*SSHConfig)(nil), // 28: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 29: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 30: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 31: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 32: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 33: management.ProviderConfig - (*Route)(nil), // 34: management.Route - (*DNSConfig)(nil), // 35: management.DNSConfig - (*CustomZone)(nil), // 36: management.CustomZone - (*SimpleRecord)(nil), // 37: management.SimpleRecord - (*NameServerGroup)(nil), // 38: management.NameServerGroup - (*NameServer)(nil), // 39: management.NameServer - (*FirewallRule)(nil), // 40: management.FirewallRule - (*NetworkAddress)(nil), // 41: management.NetworkAddress - (*Checks)(nil), // 42: management.Checks - (*PortInfo)(nil), // 43: management.PortInfo - (*RouteFirewallRule)(nil), // 44: management.RouteFirewallRule - (*ForwardingRule)(nil), // 45: management.ForwardingRule - (*PortInfo_Range)(nil), // 46: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 47: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 48: google.protobuf.Duration + (*SSHAuth)(nil), // 27: management.SSHAuth + (*MachineUserIndexes)(nil), // 28: management.MachineUserIndexes + (*RemotePeerConfig)(nil), // 29: management.RemotePeerConfig + (*SSHConfig)(nil), // 30: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 31: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 32: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 33: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 34: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 35: management.ProviderConfig + (*Route)(nil), // 36: management.Route + (*DNSConfig)(nil), // 37: management.DNSConfig + (*CustomZone)(nil), // 38: management.CustomZone + (*SimpleRecord)(nil), // 39: management.SimpleRecord + (*NameServerGroup)(nil), // 40: management.NameServerGroup + (*NameServer)(nil), // 41: management.NameServer + (*FirewallRule)(nil), // 42: management.FirewallRule + (*NetworkAddress)(nil), // 43: management.NetworkAddress + (*Checks)(nil), // 44: management.Checks + (*PortInfo)(nil), // 45: management.PortInfo + (*RouteFirewallRule)(nil), // 46: management.RouteFirewallRule + (*ForwardingRule)(nil), // 47: management.ForwardingRule + nil, // 48: management.SSHAuth.MachineUsersEntry + (*PortInfo_Range)(nil), // 49: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 50: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 51: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ 14, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta 18, // 1: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig 24, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 27, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 29, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig 26, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 42, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 44, // 5: management.SyncResponse.Checks:type_name -> management.Checks 14, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 14, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 41, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 43, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment 12, // 11: management.PeerSystemMeta.files:type_name -> management.File 13, // 12: management.PeerSystemMeta.flags:type_name -> management.Flags 18, // 13: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig 24, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 42, // 15: management.LoginResponse.Checks:type_name -> management.Checks - 47, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 44, // 15: management.LoginResponse.Checks:type_name -> management.Checks + 50, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 19, // 17: management.NetbirdConfig.stuns:type_name -> management.HostConfig 23, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 48, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 51, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 28, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 30, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig 25, // 26: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 27, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 34, // 29: management.NetworkMap.Routes:type_name -> management.Route - 35, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 27, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 40, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 44, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 45, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 28, // 35: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 22, // 36: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 4, // 37: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 33, // 38: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 33, // 39: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 38, // 40: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 36, // 41: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 37, // 42: management.CustomZone.Records:type_name -> management.SimpleRecord - 39, // 43: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 44: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 45: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 46: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 43, // 47: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 46, // 48: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 49: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 50: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 43, // 51: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 52: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 43, // 53: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 43, // 54: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 55: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 56: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 57: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 58: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 59: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 63: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 64: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 65: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 66: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 67: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 68: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 69: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 70: management.ManagementService.Logout:output_type -> management.Empty - 63, // [63:71] is the sub-list for method output_type - 55, // [55:63] is the sub-list for method input_type - 55, // [55:55] is the sub-list for extension type_name - 55, // [55:55] is the sub-list for extension extendee - 0, // [0:55] is the sub-list for field type_name + 29, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 36, // 29: management.NetworkMap.Routes:type_name -> management.Route + 37, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 29, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 42, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 46, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 47, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 27, // 35: management.NetworkMap.sshAuth:type_name -> management.SSHAuth + 48, // 36: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry + 30, // 37: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 38: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 39: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 35, // 40: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 35, // 41: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 40, // 42: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 38, // 43: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 39, // 44: management.CustomZone.Records:type_name -> management.SimpleRecord + 41, // 45: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 46: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 47: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 48: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 45, // 49: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 49, // 50: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 51: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 52: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 45, // 53: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 54: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 45, // 55: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 45, // 56: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 28, // 57: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes + 5, // 58: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 59: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 60: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 61: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 62: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 64: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 65: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 66: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 67: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 68: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 69: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 70: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 71: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 72: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 73: management.ManagementService.Logout:output_type -> management.Empty + 66, // [66:74] is the sub-list for method output_type + 58, // [58:66] is the sub-list for method input_type + 58, // [58:58] is the sub-list for extension type_name + 58, // [58:58] is the sub-list for extension extendee + 0, // [0:58] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -4558,7 +4708,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*SSHAuth); i { case 0: return &v.state case 1: @@ -4570,7 +4720,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*MachineUserIndexes); i { case 0: return &v.state case 1: @@ -4582,7 +4732,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -4594,7 +4744,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -4606,7 +4756,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4618,7 +4768,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4630,7 +4780,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4642,7 +4792,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4654,7 +4804,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -4666,7 +4816,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -4678,7 +4828,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -4690,7 +4840,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -4702,7 +4852,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -4714,7 +4864,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -4726,7 +4876,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -4738,7 +4888,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Checks); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -4750,7 +4900,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PortInfo); i { + switch v := v.(*NetworkAddress); i { case 0: return &v.state case 1: @@ -4762,7 +4912,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouteFirewallRule); i { + switch v := v.(*Checks); i { case 0: return &v.state case 1: @@ -4774,7 +4924,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ForwardingRule); i { + switch v := v.(*PortInfo); i { case 0: return &v.state case 1: @@ -4786,6 +4936,30 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RouteFirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForwardingRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[44].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -4798,7 +4972,7 @@ func file_management_proto_init() { } } } - file_management_proto_msgTypes[38].OneofWrappers = []interface{}{ + file_management_proto_msgTypes[40].OneofWrappers = []interface{}{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } @@ -4808,7 +4982,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 42, + NumMessages: 45, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index fec51ca91..f2e591e88 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -332,6 +332,24 @@ message NetworkMap { bool routesFirewallRulesIsEmpty = 11; repeated ForwardingRule forwardingRules = 12; + + // SSHAuth represents SSH authorization configuration + SSHAuth sshAuth = 13; +} + +message SSHAuth { + // UserIDClaim is the JWT claim to be used to get the users ID + string UserIDClaim = 1; + + // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH + repeated bytes AuthorizedUsers = 2; + + // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list + map machine_users = 3; +} + +message MachineUserIndexes { + repeated uint32 indexes = 1; } // RemotePeerConfig represents a configuration of a remote peer. diff --git a/shared/sshauth/userhash.go b/shared/sshauth/userhash.go new file mode 100644 index 000000000..276fc9ba2 --- /dev/null +++ b/shared/sshauth/userhash.go @@ -0,0 +1,28 @@ +package sshauth + +import ( + "encoding/hex" + + "golang.org/x/crypto/blake2b" +) + +// UserIDHash represents a hashed user ID (BLAKE2b-128) +type UserIDHash [16]byte + +// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value +// This function must produce the same hash on both client and management server +func HashUserID(userID string) (UserIDHash, error) { + hash, err := blake2b.New(16, nil) + if err != nil { + return UserIDHash{}, err + } + hash.Write([]byte(userID)) + var result UserIDHash + copy(result[:], hash.Sum(nil)) + return result, nil +} + +// String returns the hexadecimal string representation of the hash +func (h UserIDHash) String() string { + return hex.EncodeToString(h[:]) +} diff --git a/shared/sshauth/userhash_test.go b/shared/sshauth/userhash_test.go new file mode 100644 index 000000000..5a3cb6986 --- /dev/null +++ b/shared/sshauth/userhash_test.go @@ -0,0 +1,210 @@ +package sshauth + +import ( + "testing" +) + +func TestHashUserID(t *testing.T) { + tests := []struct { + name string + userID string + }{ + { + name: "simple user ID", + userID: "user@example.com", + }, + { + name: "UUID format", + userID: "550e8400-e29b-41d4-a716-446655440000", + }, + { + name: "numeric ID", + userID: "12345", + }, + { + name: "empty string", + userID: "", + }, + { + name: "special characters", + userID: "user+test@domain.com", + }, + { + name: "unicode characters", + userID: "用户@example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashUserID(tt.userID) + if err != nil { + t.Errorf("HashUserID() error = %v, want nil", err) + return + } + + // Verify hash is non-zero for non-empty inputs + if tt.userID != "" && hash == [16]byte{} { + t.Errorf("HashUserID() returned zero hash for non-empty input") + } + }) + } +} + +func TestHashUserID_Consistency(t *testing.T) { + userID := "test@example.com" + + hash1, err1 := HashUserID(userID) + if err1 != nil { + t.Fatalf("First HashUserID() error = %v", err1) + } + + hash2, err2 := HashUserID(userID) + if err2 != nil { + t.Fatalf("Second HashUserID() error = %v", err2) + } + + if hash1 != hash2 { + t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2) + } +} + +func TestHashUserID_Uniqueness(t *testing.T) { + tests := []struct { + userID1 string + userID2 string + }{ + {"user1@example.com", "user2@example.com"}, + {"alice@domain.com", "bob@domain.com"}, + {"test", "test1"}, + {"", "a"}, + } + + for _, tt := range tests { + hash1, err1 := HashUserID(tt.userID1) + if err1 != nil { + t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1) + } + + hash2, err2 := HashUserID(tt.userID2) + if err2 != nil { + t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2) + } + + if hash1 == hash2 { + t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1) + } + } +} + +func TestUserIDHash_String(t *testing.T) { + tests := []struct { + name string + hash UserIDHash + expected string + }{ + { + name: "zero hash", + hash: [16]byte{}, + expected: "00000000000000000000000000000000", + }, + { + name: "small value", + hash: [16]byte{15: 0xff}, + expected: "000000000000000000000000000000ff", + }, + { + name: "large value", + hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe}, + expected: "0000000000000000deadbeefcafebabe", + }, + { + name: "max value", + hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + expected: "ffffffffffffffffffffffffffffffff", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.hash.String() + if result != tt.expected { + t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestUserIDHash_String_Length(t *testing.T) { + // Test that String() always returns 32 hex characters (16 bytes * 2) + userID := "test@example.com" + hash, err := HashUserID(userID) + if err != nil { + t.Fatalf("HashUserID() error = %v", err) + } + + result := hash.String() + if len(result) != 32 { + t.Errorf("UserIDHash.String() length = %d, want 32", len(result)) + } + + // Verify it's valid hex + for i, c := range result { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c) + } + } +} + +func TestHashUserID_KnownValues(t *testing.T) { + // Test with known BLAKE2b-128 values to ensure correct implementation + tests := []struct { + name string + userID string + expected UserIDHash + }{ + { + name: "empty string", + userID: "", + // BLAKE2b-128 of empty string + expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70}, + }, + { + name: "single character 'a'", + userID: "a", + // BLAKE2b-128 of "a" + expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashUserID(tt.userID) + if err != nil { + t.Errorf("HashUserID() error = %v", err) + return + } + + if hash != tt.expected { + t.Errorf("HashUserID(%q) = %x, want %x", + tt.userID, hash, tt.expected) + } + }) + } +} + +func BenchmarkHashUserID(b *testing.B) { + userID := "user@example.com" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = HashUserID(userID) + } +} + +func BenchmarkUserIDHash_String(b *testing.B) { + hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe}) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = hash.String() + } +} From b6a327e0c96dbea034422f41c18112c27477ec60 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 29 Dec 2025 15:03:16 +0100 Subject: [PATCH 116/118] [management] fix scanning authorized user on policy rule (#5002) --- management/server/store/sql_store.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index da05ba416..d2220d4b4 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -1919,7 +1919,8 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t var r types.PolicyRule var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &r.AuthorizedUser) + var authorizedUser sql.NullString + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser) if err == nil { if enabled.Valid { r.Enabled = enabled.Bool @@ -1948,6 +1949,9 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t if authorizedGroups != nil { _ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups) } + if authorizedUser.Valid { + r.AuthorizedUser = authorizedUser.String + } } return &r, err }) From 38f9d5ed58194d23453bb094550c50a99ba42408 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 29 Dec 2025 16:07:06 +0100 Subject: [PATCH 117/118] [infra] Preset signal port on templates (#5004) When passing certificates to signal, it will select port 443 when no port is supplied. This changes forces port 80. --- infrastructure_files/docker-compose.yml.tmpl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 2bc49d3e5..1c9c63f78 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -53,7 +53,8 @@ services: command: [ "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", - "--log-file", "console" + "--log-file", "console", + "--port", "80" ] # Relay From e11970e32e174a3020a4883aef590c4904d7519b Mon Sep 17 00:00:00 2001 From: Louis Li <32395144+gamerslouis@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:37:49 +0800 Subject: [PATCH 118/118] [client] add reset for management backoff (#4935) Reset client management grpc client backoff after successful connected to management API. Current Situation: If the connection duration exceeds MaxElapsedTime, when the connection is interrupted, the backoff fails immediately due to timeout and does not actually perform a retry. --- shared/management/client/grpc.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 520a83e36..89860ac9b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -111,6 +111,8 @@ func (c *GrpcClient) ready() bool { // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Blocking request. The result will be sent via msgHandler callback function func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error { + backOff := defaultBackoff(ctx) + operation := func() error { log.Debugf("management connection state %v", c.conn.GetState()) connState := c.conn.GetState() @@ -128,10 +130,10 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return err } - return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) + return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff) } - err := backoff.Retry(operation, defaultBackoff(ctx)) + err := backoff.Retry(operation, backOff) if err != nil { log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err) } @@ -140,7 +142,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler } func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, - msgHandler func(msg *proto.SyncResponse) error) error { + msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error { ctx, cancelStream := context.WithCancel(ctx) defer cancelStream() @@ -158,6 +160,9 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, // blocking until error err = c.receiveEvents(stream, serverPubKey, msgHandler) + // we need this reset because after a successful connection and a consequent error, backoff lib doesn't + // reset times and next try will start with a long delay + backOff.Reset() if err != nil { c.notifyDisconnected(err) s, _ := gstatus.FromError(err)