mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 23:36:39 +00:00
Compare commits
10 Commits
refactor/r
...
feature/st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df101bf071 | ||
|
|
8393bf1b17 | ||
|
|
02a04958e7 | ||
|
|
000e99e7f3 | ||
|
|
0d2e67983a | ||
|
|
5151f19d29 | ||
|
|
bedd3cabc9 | ||
|
|
d35a845dbd | ||
|
|
4e03f708a4 | ||
|
|
654aa9581d |
@@ -31,6 +31,7 @@ const (
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||
}
|
||||
|
||||
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
|
||||
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
|
||||
@@ -40,7 +40,6 @@ type Manager struct {
|
||||
fwRules []firewall.Rule
|
||||
tcpRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
port uint16
|
||||
}
|
||||
|
||||
func ListenPort() uint16 {
|
||||
@@ -49,11 +48,16 @@ func ListenPort() uint16 {
|
||||
return listenPort
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
|
||||
func SetListenPort(port uint16) {
|
||||
listenPortMu.Lock()
|
||||
listenPort = port
|
||||
listenPortMu.Unlock()
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.port > 0 {
|
||||
listenPortMu.Lock()
|
||||
listenPort = m.port
|
||||
listenPortMu.Unlock()
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
|
||||
@@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder(
|
||||
return
|
||||
}
|
||||
|
||||
if forwarderPort > 0 {
|
||||
dnsfwd.SetListenPort(forwarderPort)
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder(
|
||||
if len(fwdEntries) > 0 {
|
||||
switch {
|
||||
case e.dnsForwardMgr == nil:
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
@@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
|
||||
@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||
return &tls.Config{
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
|
||||
if requiresCredSSP {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS12
|
||||
} else {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS13
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -19,18 +21,34 @@ const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
|
||||
rdpDialTimeout = 15 * time.Second
|
||||
|
||||
GeneralErrorCode = 1
|
||||
WSAETimedOut = 10060
|
||||
WSAEConnRefused = 10061
|
||||
WSAEConnAborted = 10053
|
||||
WSAEConnReset = 10054
|
||||
WSAEGenericError = 10050
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathErr struct {
|
||||
ErrorCode int16 `asn1:"tag:0,explicit"`
|
||||
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
|
||||
WSALastError int16 `asn1:"tag:2,explicit,optional"`
|
||||
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func errorToWSACode(err error) int16 {
|
||||
if err == nil {
|
||||
return WSAEGenericError
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return WSAEConnAborted
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return WSAEConnReset
|
||||
}
|
||||
return WSAEGenericError
|
||||
}
|
||||
|
||||
func newWSAError(err error) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
WSALastError: errorToWSACode(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPError(statusCode int16) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
HTTPStatusCode: statusCode,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
@@ -11,11 +12,17 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
|
||||
protocolSSL = 0x00000001
|
||||
protocolHybridEx = 0x00000008
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||
p.sendRDCleanPathError(conn, newHTTPError(400))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, "Connection failed")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
|
||||
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
|
||||
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
|
||||
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
|
||||
const minResponseLength = 19
|
||||
|
||||
if len(x224Response) < minResponseLength {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
// Per X.224 specification:
|
||||
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
|
||||
// x224Response[5] == 0xD0: X.224 Data TPDU code
|
||||
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
if x224Response[11] == 0x02 {
|
||||
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
|
||||
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
|
||||
|
||||
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
|
||||
return hasNLA, flags, true
|
||||
}
|
||||
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
|
||||
if detected {
|
||||
if requiresCredSSP {
|
||||
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
|
||||
} else {
|
||||
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
X224ConnectionPDU: response[:n],
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
} else {
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
}
|
||||
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TCP connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||
pdu := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: []byte(errorMsg),
|
||||
}
|
||||
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -62,7 +62,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -49,6 +49,7 @@ services:
|
||||
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
|
||||
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
|
||||
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
|
||||
- traefik.http.routers.netbird-signal.service=netbird-signal
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
|
||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
|
||||
@@ -3,16 +3,16 @@ package integrated_validator
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||
type IntegratedValidator interface {
|
||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
||||
|
||||
@@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
@@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers && userID != activity.SystemInitiator {
|
||||
if userID != activity.SystemInitiator {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -584,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
|
||||
}
|
||||
|
||||
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
@@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
if updateAccountPeers {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
}
|
||||
@@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
|
||||
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
|
||||
}
|
||||
|
||||
// deletePeers deletes all specified peers and sends updates to the remote peers.
|
||||
// Returns a slice of functions to save events after successful peer deletion.
|
||||
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
|
||||
|
||||
@@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg) //
|
||||
peerShouldReceiveUpdate(t, updMsg) //
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
129
management/server/store/cache/dual_key_cache.go
vendored
Normal file
129
management/server/store/cache/dual_key_cache.go
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DualKeyCache provides a caching mechanism where each entry has two keys:
|
||||
// - Primary key (e.g., objectID): used for accessing and invalidating specific entries
|
||||
// - Secondary key (e.g., accountID): used for bulk invalidation of all entries with the same secondary key
|
||||
type DualKeyCache[K1 comparable, K2 comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
primaryIndex map[K1]V // Primary key -> Value
|
||||
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
|
||||
reverseLookup map[K1]K2 // Primary key -> Secondary key
|
||||
}
|
||||
|
||||
// NewDualKeyCache creates a new dual-key cache
|
||||
func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] {
|
||||
return &DualKeyCache[K1, K2, V]{
|
||||
primaryIndex: make(map[K1]V),
|
||||
secondaryIndex: make(map[K2]map[K1]struct{}),
|
||||
reverseLookup: make(map[K1]K2),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache using the primary key
|
||||
func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
value, ok := c.primaryIndex[primaryKey]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with both primary and secondary keys
|
||||
func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if oldSecondaryKey, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[oldSecondaryKey]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, oldSecondaryKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.primaryIndex[primaryKey] = value
|
||||
c.reverseLookup[primaryKey] = secondaryKey
|
||||
|
||||
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
|
||||
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
|
||||
}
|
||||
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
|
||||
}
|
||||
|
||||
// InvalidateByPrimaryKey removes an entry using the primary key
|
||||
func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if secondaryKey, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[secondaryKey]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, secondaryKey)
|
||||
}
|
||||
}
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
}
|
||||
|
||||
// InvalidateBySecondaryKey removes all entries with the given secondary key
|
||||
func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
primaryKeys, exists := c.secondaryIndex[secondaryKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for primaryKey := range primaryKeys {
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.secondaryIndex, secondaryKey)
|
||||
}
|
||||
|
||||
// InvalidateAll removes all entries from the cache
|
||||
func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.primaryIndex = make(map[K1]V)
|
||||
c.secondaryIndex = make(map[K2]map[K1]struct{})
|
||||
c.reverseLookup = make(map[K1]K2)
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the cache
|
||||
func (c *DualKeyCache[K1, K2, V]) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.primaryIndex)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
|
||||
// The loadFunc should return both the value and the secondary key (extracted from the value)
|
||||
func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) {
|
||||
if value, ok := c.Get(ctx, primaryKey); ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
value, secondaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
77
management/server/store/cache/single_key_cache.go
vendored
Normal file
77
management/server/store/cache/single_key_cache.go
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SingleKeyCache provides a simple caching mechanism with a single key
|
||||
type SingleKeyCache[K comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
cache map[K]V // Key -> Value
|
||||
}
|
||||
|
||||
// NewSingleKeyCache creates a new single-key cache
|
||||
func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] {
|
||||
return &SingleKeyCache[K, V]{
|
||||
cache: make(map[K]V),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache using the key
|
||||
func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
value, ok := c.cache[key]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the given key
|
||||
func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache[key] = value
|
||||
}
|
||||
|
||||
// Invalidate removes an entry using the key
|
||||
func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
delete(c.cache, key)
|
||||
}
|
||||
|
||||
// InvalidateAll removes all entries from the cache
|
||||
func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache = make(map[K]V)
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the cache
|
||||
func (c *SingleKeyCache[K, V]) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.cache)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
|
||||
func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) {
|
||||
if value, ok := c.Get(ctx, key); ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
value, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, key, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
242
management/server/store/cache/triple_key_cache.go
vendored
Normal file
242
management/server/store/cache/triple_key_cache.go
vendored
Normal file
@@ -0,0 +1,242 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TripleKeyCache provides a caching mechanism where each entry has three keys:
|
||||
// - Primary key (K1): used for accessing and invalidating specific entries
|
||||
// - Secondary key (K2): used for bulk invalidation of all entries with the same secondary key
|
||||
// - Tertiary key (K3): used for bulk invalidation of all entries with the same tertiary key
|
||||
type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
primaryIndex map[K1]V // Primary key -> Value
|
||||
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
|
||||
tertiaryIndex map[K3]map[K1]struct{} // Tertiary key -> Set of primary keys
|
||||
reverseLookup map[K1]keyPair[K2, K3] // Primary key -> Secondary and Tertiary keys
|
||||
}
|
||||
|
||||
type keyPair[K2 comparable, K3 comparable] struct {
|
||||
secondary K2
|
||||
tertiary K3
|
||||
}
|
||||
|
||||
// NewTripleKeyCache creates a new triple-key cache
|
||||
func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] {
|
||||
return &TripleKeyCache[K1, K2, K3, V]{
|
||||
primaryIndex: make(map[K1]V),
|
||||
secondaryIndex: make(map[K2]map[K1]struct{}),
|
||||
tertiaryIndex: make(map[K3]map[K1]struct{}),
|
||||
reverseLookup: make(map[K1]keyPair[K2, K3]),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache using the primary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
value, ok := c.primaryIndex[primaryKey]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with primary, secondary, and tertiary keys
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if oldKeys, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[oldKeys.secondary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, oldKeys.secondary)
|
||||
}
|
||||
}
|
||||
if primaryKeys, ok := c.tertiaryIndex[oldKeys.tertiary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.tertiaryIndex, oldKeys.tertiary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.primaryIndex[primaryKey] = value
|
||||
c.reverseLookup[primaryKey] = keyPair[K2, K3]{
|
||||
secondary: secondaryKey,
|
||||
tertiary: tertiaryKey,
|
||||
}
|
||||
|
||||
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
|
||||
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
|
||||
}
|
||||
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
|
||||
|
||||
if _, exists := c.tertiaryIndex[tertiaryKey]; !exists {
|
||||
c.tertiaryIndex[tertiaryKey] = make(map[K1]struct{})
|
||||
}
|
||||
c.tertiaryIndex[tertiaryKey][primaryKey] = struct{}{}
|
||||
}
|
||||
|
||||
// InvalidateByPrimaryKey removes an entry using the primary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if keys, exists := c.reverseLookup[primaryKey]; exists {
|
||||
if primaryKeys, ok := c.secondaryIndex[keys.secondary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, keys.secondary)
|
||||
}
|
||||
}
|
||||
if primaryKeys, ok := c.tertiaryIndex[keys.tertiary]; ok {
|
||||
delete(primaryKeys, primaryKey)
|
||||
if len(primaryKeys) == 0 {
|
||||
delete(c.tertiaryIndex, keys.tertiary)
|
||||
}
|
||||
}
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
}
|
||||
|
||||
// InvalidateBySecondaryKey removes all entries with the given secondary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
primaryKeys, exists := c.secondaryIndex[secondaryKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for primaryKey := range primaryKeys {
|
||||
if keys, ok := c.reverseLookup[primaryKey]; ok {
|
||||
if tertiaryPrimaryKeys, exists := c.tertiaryIndex[keys.tertiary]; exists {
|
||||
delete(tertiaryPrimaryKeys, primaryKey)
|
||||
if len(tertiaryPrimaryKeys) == 0 {
|
||||
delete(c.tertiaryIndex, keys.tertiary)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.secondaryIndex, secondaryKey)
|
||||
}
|
||||
|
||||
// InvalidateByTertiaryKey removes all entries with the given tertiary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
primaryKeys, exists := c.tertiaryIndex[tertiaryKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for primaryKey := range primaryKeys {
|
||||
if keys, ok := c.reverseLookup[primaryKey]; ok {
|
||||
if secondaryPrimaryKeys, exists := c.secondaryIndex[keys.secondary]; exists {
|
||||
delete(secondaryPrimaryKeys, primaryKey)
|
||||
if len(secondaryPrimaryKeys) == 0 {
|
||||
delete(c.secondaryIndex, keys.secondary)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(c.primaryIndex, primaryKey)
|
||||
delete(c.reverseLookup, primaryKey)
|
||||
}
|
||||
|
||||
delete(c.tertiaryIndex, tertiaryKey)
|
||||
}
|
||||
|
||||
// InvalidateAll removes all entries from the cache
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.primaryIndex = make(map[K1]V)
|
||||
c.secondaryIndex = make(map[K2]map[K1]struct{})
|
||||
c.tertiaryIndex = make(map[K3]map[K1]struct{})
|
||||
c.reverseLookup = make(map[K1]keyPair[K2, K3])
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the cache
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.primaryIndex)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
|
||||
// The loadFunc should return the value, secondary key, and tertiary key (extracted from the value)
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) {
|
||||
if value, ok := c.Get(ctx, primaryKey); ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
value, secondaryKey, tertiaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// GetOrSetBySecondaryKey retrieves a value from the cache using the secondary key, or sets it using the provided function if not found
|
||||
// The loadFunc should return the value, primary key, secondary key, and tertiary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) {
|
||||
c.mu.RLock()
|
||||
if primaryKeys, exists := c.secondaryIndex[secondaryKey]; exists && len(primaryKeys) > 0 {
|
||||
for primaryKey := range primaryKeys {
|
||||
if value, ok := c.primaryIndex[primaryKey]; ok {
|
||||
c.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
value, primaryKey, tertiaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// GetOrSetByTertiaryKey retrieves a value from the cache using the tertiary key, or sets it using the provided function if not found
|
||||
// The loadFunc should return the value, primary key, secondary key, and tertiary key
|
||||
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) {
|
||||
c.mu.RLock()
|
||||
if primaryKeys, exists := c.tertiaryIndex[tertiaryKey]; exists && len(primaryKeys) > 0 {
|
||||
for primaryKey := range primaryKeys {
|
||||
if value, ok := c.primaryIndex[primaryKey]; ok {
|
||||
c.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
value, primaryKey, secondaryKey, err := loadFunc()
|
||||
if err != nil {
|
||||
var zero V
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
@@ -14,6 +15,8 @@ type StoreMetrics struct {
|
||||
persistenceDurationMicro metric.Int64Histogram
|
||||
persistenceDurationMs metric.Int64Histogram
|
||||
transactionDurationMs metric.Int64Histogram
|
||||
queryDurationMs metric.Int64Histogram
|
||||
queryCounter metric.Int64Counter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -59,12 +62,29 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryDurationMs, err := meter.Int64Histogram("management.store.query.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of database query operations with operation type and table name"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryCounter, err := meter.Int64Counter("management.store.query.count",
|
||||
metric.WithDescription("Count of database query operations with operation type, table name, and status"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &StoreMetrics{
|
||||
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
|
||||
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
|
||||
persistenceDurationMicro: persistenceDurationMicro,
|
||||
persistenceDurationMs: persistenceDurationMs,
|
||||
transactionDurationMs: transactionDurationMs,
|
||||
queryDurationMs: queryDurationMs,
|
||||
queryCounter: queryCounter,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -85,3 +105,13 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
|
||||
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
|
||||
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
// CountStoreOperation records a store operation with its method name, status, and duration
|
||||
func (metrics *StoreMetrics) CountStoreOperation(method string, duration time.Duration) {
|
||||
attrs := []attribute.KeyValue{
|
||||
attribute.String("method", method),
|
||||
}
|
||||
|
||||
metrics.queryDurationMs.Record(metrics.ctx, duration.Milliseconds(), metric.WithAttributes(attrs...))
|
||||
metrics.queryCounter.Add(metrics.ctx, 1, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package version
|
||||
|
||||
import "golang.org/x/sys/windows/registry"
|
||||
import (
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
urlWinExe = "https://pkgs.netbird.io/windows/x64"
|
||||
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
|
||||
)
|
||||
|
||||
var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
|
||||
@@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
|
||||
// DownloadUrl return with the proper download link
|
||||
func DownloadUrl() string {
|
||||
_, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
|
||||
if err == nil {
|
||||
return urlWinExe
|
||||
} else {
|
||||
if err != nil {
|
||||
return downloadURL
|
||||
}
|
||||
|
||||
url := urlWinExe
|
||||
if runtime.GOARCH == "arm64" {
|
||||
url = urlWinExeArm
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user