Compare commits

...

10 Commits

Author SHA1 Message Date
Pascal Fischer
df101bf071 add custom store cache 2025-10-17 15:57:40 +02:00
Pascal Fischer
8393bf1b17 pass metrics in transactions 2025-10-15 16:29:36 +02:00
Pascal Fischer
02a04958e7 add metrics to store methods 2025-10-14 21:28:52 +02:00
Viktor Liu
000e99e7f3 [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) 2025-10-13 17:50:16 +02:00
Maycon Santos
0d2e67983a [misc] Add service definition for netbird-signal (#4620) 2025-10-10 19:16:48 +02:00
Pascal Fischer
5151f19d29 [management] pass temporary flag to validator (#4599) 2025-10-10 16:15:51 +02:00
Kostya Leschenko
bedd3cabc9 [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) 2025-10-10 15:24:24 +02:00
hakansa
d35a845dbd [management] sync all other peers on peer add/remove (#4614) 2025-10-09 21:18:00 +02:00
hakansa
4e03f708a4 fix dns forwarder port update (#4613)
fix dns forwarder port update (#4613)
2025-10-09 17:39:02 +03:00
Ashley
654aa9581d [client,gui] Update url_windows.go to offer arm64 executable download (#4586) 2025-10-08 21:27:32 +02:00
19 changed files with 1766 additions and 374 deletions

View File

@@ -31,6 +31,7 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Warnf("failed to set DNSSEC to 'no': %v", err)
}
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
}
var (
searchDomains []string
matchDomains []string

View File

@@ -40,7 +40,6 @@ type Manager struct {
fwRules []firewall.Rule
tcpRules []firewall.Rule
dnsForwarder *DNSForwarder
port uint16
}
func ListenPort() uint16 {
@@ -49,11 +48,16 @@ func ListenPort() uint16 {
return listenPort
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
func SetListenPort(port uint16) {
listenPortMu.Lock()
listenPort = port
listenPortMu.Unlock()
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{
firewall: fw,
statusRecorder: statusRecorder,
port: port,
}
}
@@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err
}
if m.port > 0 {
listenPortMu.Lock()
listenPort = m.port
listenPortMu.Unlock()
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {

View File

@@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder(
return
}
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder(
if len(fwdEntries) > 0 {
switch {
case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
@@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil

View File

@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
return &tls.Config{
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
return nil
},
}
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
}

View File

@@ -6,11 +6,13 @@ import (
"context"
"crypto/tls"
"encoding/asn1"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
@@ -19,18 +21,34 @@ const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
Version int64 `asn1:"tag:0,explicit"`
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathErr struct {
ErrorCode int16 `asn1:"tag:0,explicit"`
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
WSALastError int16 `asn1:"tag:2,explicit,optional"`
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
}
type RDCleanPathProxy struct {
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func errorToWSACode(err error) int16 {
if err == nil {
return WSAEGenericError
}
var netErr *net.OpError
if errors.As(err, &netErr) && netErr.Timeout() {
return WSAETimedOut
}
if errors.Is(err, context.DeadlineExceeded) {
return WSAETimedOut
}
if errors.Is(err, context.Canceled) {
return WSAEConnAborted
}
if errors.Is(err, io.EOF) {
return WSAEConnReset
}
return WSAEGenericError
}
func newWSAError(err error) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
WSALastError: errorToWSACode(err),
},
}
}
func newHTTPError(statusCode int16) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
HTTPStatusCode: statusCode,
},
}
}

View File

@@ -3,6 +3,7 @@
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
@@ -11,11 +12,17 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, "Unsupported version")
p.sendRDCleanPathError(conn, newHTTPError(400))
return
}
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
destination = pdu.Destination
}
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, "Connection failed")
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
p.setupTLSConnection(conn, pdu)
}
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
const minResponseLength = 19
if len(x224Response) < minResponseLength {
return false, 0, false
}
// Per X.224 specification:
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
// x224Response[5] == 0xD0: X.224 Data TPDU code
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
return false, 0, false
}
if x224Response[11] == 0x02 {
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
return hasNLA, flags, true
}
return false, 0, false
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
tlsConfig := p.getTLSConfigWithValidation(conn)
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, "TLS handshake failed")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)

2
go.mod
View File

@@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -49,6 +49,7 @@ services:
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.routers.netbird-signal.service=netbird-signal
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer
}

View File

@@ -3,16 +3,16 @@ package integrated_validator
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error

View File

@@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
var peer *nbpeer.Peer
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
if err != nil {
return err
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if updateAccountPeers && userID != activity.SystemInitiator {
if userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -584,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
@@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if updateAccountPeers {
am.BufferUpdateAccountPeers(ctx, accountID)
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
@@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
if err != nil {
return false, err
}
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
}
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {

View File

@@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg) //
peerShouldReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -0,0 +1,129 @@
package cache
import (
"context"
"sync"
)
// DualKeyCache provides a caching mechanism where each entry has two keys:
// - Primary key (e.g., objectID): used for accessing and invalidating specific entries
// - Secondary key (e.g., accountID): used for bulk invalidation of all entries with the same secondary key
type DualKeyCache[K1 comparable, K2 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
reverseLookup map[K1]K2 // Primary key -> Secondary key
}
// NewDualKeyCache creates a new dual-key cache
func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] {
return &DualKeyCache[K1, K2, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
reverseLookup: make(map[K1]K2),
}
}
// Get retrieves a value from the cache using the primary key
func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with both primary and secondary keys
func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldSecondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldSecondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldSecondaryKey)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = secondaryKey
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if secondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[secondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, secondaryKey)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.reverseLookup = make(map[K1]K2)
}
// Size returns the number of entries in the cache
func (c *DualKeyCache[K1, K2, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return both the value and the secondary key (extracted from the value)
func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, value)
return value, nil
}

View File

@@ -0,0 +1,77 @@
package cache
import (
"context"
"sync"
)
// SingleKeyCache provides a simple caching mechanism with a single key
type SingleKeyCache[K comparable, V any] struct {
mu sync.RWMutex
cache map[K]V // Key -> Value
}
// NewSingleKeyCache creates a new single-key cache
func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] {
return &SingleKeyCache[K, V]{
cache: make(map[K]V),
}
}
// Get retrieves a value from the cache using the key
func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.cache[key]
return value, ok
}
// Set stores a value in the cache with the given key
func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache[key] = value
}
// Invalidate removes an entry using the key
func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, key)
}
// InvalidateAll removes all entries from the cache
func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache = make(map[K]V)
}
// Size returns the number of entries in the cache
func (c *SingleKeyCache[K, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.cache)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) {
if value, ok := c.Get(ctx, key); ok {
return value, nil
}
value, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, key, value)
return value, nil
}

View File

@@ -0,0 +1,242 @@
package cache
import (
"context"
"sync"
)
// TripleKeyCache provides a caching mechanism where each entry has three keys:
// - Primary key (K1): used for accessing and invalidating specific entries
// - Secondary key (K2): used for bulk invalidation of all entries with the same secondary key
// - Tertiary key (K3): used for bulk invalidation of all entries with the same tertiary key
type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
tertiaryIndex map[K3]map[K1]struct{} // Tertiary key -> Set of primary keys
reverseLookup map[K1]keyPair[K2, K3] // Primary key -> Secondary and Tertiary keys
}
type keyPair[K2 comparable, K3 comparable] struct {
secondary K2
tertiary K3
}
// NewTripleKeyCache creates a new triple-key cache
func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] {
return &TripleKeyCache[K1, K2, K3, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
tertiaryIndex: make(map[K3]map[K1]struct{}),
reverseLookup: make(map[K1]keyPair[K2, K3]),
}
}
// Get retrieves a value from the cache using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with primary, secondary, and tertiary keys
func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldKeys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldKeys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldKeys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[oldKeys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, oldKeys.tertiary)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = keyPair[K2, K3]{
secondary: secondaryKey,
tertiary: tertiaryKey,
}
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
if _, exists := c.tertiaryIndex[tertiaryKey]; !exists {
c.tertiaryIndex[tertiaryKey] = make(map[K1]struct{})
}
c.tertiaryIndex[tertiaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if keys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[keys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[keys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if tertiaryPrimaryKeys, exists := c.tertiaryIndex[keys.tertiary]; exists {
delete(tertiaryPrimaryKeys, primaryKey)
if len(tertiaryPrimaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateByTertiaryKey removes all entries with the given tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.tertiaryIndex[tertiaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if secondaryPrimaryKeys, exists := c.secondaryIndex[keys.secondary]; exists {
delete(secondaryPrimaryKeys, primaryKey)
if len(secondaryPrimaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.tertiaryIndex, tertiaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.tertiaryIndex = make(map[K3]map[K1]struct{})
c.reverseLookup = make(map[K1]keyPair[K2, K3])
}
// Size returns the number of entries in the cache
func (c *TripleKeyCache[K1, K2, K3, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return the value, secondary key, and tertiary key (extracted from the value)
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetBySecondaryKey retrieves a value from the cache using the secondary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.secondaryIndex[secondaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetByTertiaryKey retrieves a value from the cache using the tertiary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.tertiaryIndex[tertiaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ import (
"context"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
@@ -14,6 +15,8 @@ type StoreMetrics struct {
persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram
transactionDurationMs metric.Int64Histogram
queryDurationMs metric.Int64Histogram
queryCounter metric.Int64Counter
ctx context.Context
}
@@ -59,12 +62,29 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err
}
queryDurationMs, err := meter.Int64Histogram("management.store.query.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of database query operations with operation type and table name"),
)
if err != nil {
return nil, err
}
queryCounter, err := meter.Int64Counter("management.store.query.count",
metric.WithDescription("Count of database query operations with operation type, table name, and status"),
)
if err != nil {
return nil, err
}
return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs,
transactionDurationMs: transactionDurationMs,
queryDurationMs: queryDurationMs,
queryCounter: queryCounter,
ctx: ctx,
}, nil
}
@@ -85,3 +105,13 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
}
// CountStoreOperation records a store operation with its method name, status, and duration
func (metrics *StoreMetrics) CountStoreOperation(method string, duration time.Duration) {
attrs := []attribute.KeyValue{
attribute.String("method", method),
}
metrics.queryDurationMs.Record(metrics.ctx, duration.Milliseconds(), metric.WithAttributes(attrs...))
metrics.queryCounter.Add(metrics.ctx, 1, metric.WithAttributes(attrs...))
}

View File

@@ -1,9 +1,13 @@
package version
import "golang.org/x/sys/windows/registry"
import (
"golang.org/x/sys/windows/registry"
"runtime"
)
const (
urlWinExe = "https://pkgs.netbird.io/windows/x64"
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
)
var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
@@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
// DownloadUrl return with the proper download link
func DownloadUrl() string {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
if err == nil {
return urlWinExe
} else {
if err != nil {
return downloadURL
}
url := urlWinExe
if runtime.GOARCH == "arm64" {
url = urlWinExeArm
}
return url
}