Compare commits

...

30 Commits

Author SHA1 Message Date
Zoltán Papp
c495eaa549 Move interface near to engine 2024-10-10 14:46:51 +02:00
Zoltán Papp
b8026ad541 Merge branch 'main' into relay/fix/wg-roaming 2024-10-09 23:24:07 +02:00
Maycon Santos
6ce09bca16 Add support to envsub go management configurations (#2708)
This change allows users to reference environment variables using Go template format, like {{ .EnvName }}

Moved the previous file test code to file_suite_test.go.
2024-10-09 20:46:23 +02:00
pascal-fischer
b79c1d64cc [management] Make max open db conns configurable (#2713) 2024-10-09 20:17:25 +02:00
Zoltán Papp
a5deeda727 Revert force install change 2024-10-09 19:20:20 +02:00
Zoltán Papp
5b2d5f8df1 Try to force install libpcap 2024-10-09 19:12:33 +02:00
Zoltán Papp
6369706ade Merge branch 'main' into relay/fix/wg-roaming 2024-10-09 18:54:30 +02:00
Misha Bragin
b1eda43f4b Add Link to the Lawrence Systems video (#2711) 2024-10-09 14:56:25 +02:00
pascal-fischer
d4ef84fe6e [management] Propagate error in store errors (#2709) 2024-10-09 14:33:58 +02:00
Zoltan Papp
e3dfbe5acf Add trace log 2024-10-09 14:07:35 +02:00
Zoltan Papp
deeb05047d Handle addr resolve error 2024-10-09 14:05:43 +02:00
Zoltan Papp
1814b07a4b Replace error check to errors.Is 2024-10-09 14:02:23 +02:00
Zoltán Papp
b04d19bb0a Fix nil pointer in error handling 2024-10-08 12:04:57 +02:00
Viktor Liu
44e8107383 [client] Limit P2P attempts and restart on specific events (#2657) 2024-10-08 11:21:11 +02:00
Bethuel Mmbaga
2c1f5e46d5 [management] Validate peer ownership during login (#2704)
* check peer ownership in login

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update error message

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-07 19:06:26 +03:00
Zoltán Papp
20815c9f90 Remove unused function 2024-10-07 13:28:21 +02:00
Zoltán Papp
ba3cdb30ee Remove unnecessary ctx cancel check 2024-10-07 13:05:11 +02:00
Zoltán Papp
1f25bb0751 Reducate cognitive complexity 2024-10-07 12:58:45 +02:00
Zoltán Papp
9e7aac3a56 Reducate cognitive complexity 2024-10-07 12:52:55 +02:00
Zoltán Papp
718d9526a7 Fix test 2024-10-07 12:45:21 +02:00
Zoltán Papp
48184ecf21 Fix eBPF pause handling 2024-10-07 12:40:53 +02:00
Zoltán Papp
f18ae8b925 Apply pause logic 2024-10-07 11:22:48 +02:00
Zoltán Papp
90d9dd4c08 Remove unused function from eBPF proxy 2024-10-07 10:35:53 +02:00
pascal-fischer
dbec24b520 [management] Remove admin check on getAccountByID (#2699) 2024-10-06 17:01:13 +02:00
Carlos Hernandez
f603cd9202 [client] Check wginterface instead of engine ctx (#2676)
Moving code to ensure wgInterface is gone right after context is
cancelled/stop in the off chance that on next retry the backoff
operation is permanently cancelled and interface is abandoned without
destroying.
2024-10-04 19:15:16 +02:00
Bethuel Mmbaga
5897a48e29 fix wrong reference (#2695)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:55:25 +03:00
Bethuel Mmbaga
8bf729c7b4 [management] Add AccountExists to AccountManager (#2694)
* Add AccountExists method to account manager interface

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove unused code

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:09:40 +03:00
Bethuel Mmbaga
7f09b39769 [management] Refactor User JWT group sync (#2690)
* Refactor GetAccountIDByUserOrAccountID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* propagate jwt group changes to peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix no jwt groups synced

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests and lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Move the account peer update outside the transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move updateUserPeersInGroups to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move event store outside of transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get user with update lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run jwt sync in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 17:17:01 +03:00
Zoltán Papp
acad98e328 Code cleaning 2024-10-03 02:29:46 +02:00
Zoltán Papp
9d75cc3273 Add pause function for proxies 2024-10-03 01:24:05 +02:00
41 changed files with 1501 additions and 589 deletions

View File

@@ -16,7 +16,7 @@ jobs:
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-latest runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5

View File

@@ -20,7 +20,7 @@ concurrency:
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-22.04
env: env:
flags: "" flags: ""
steps: steps:

View File

@@ -49,6 +49,8 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features ### Key features
@@ -62,6 +64,7 @@
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> | | | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> | | | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> | | | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -269,12 +269,6 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock() c.engineMutex.Unlock()
@@ -294,6 +288,15 @@ func (c *ConnectClient) run(
} }
<-engineCtx.Done() <-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown() c.statusRecorder.ClientTeardown()
backOff.Reset() backOff.Reset()

View File

@@ -141,7 +141,7 @@ type Engine struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgInterface iface.IWGIface wgInterface IWGIface
wgProxyFactory *wgproxy.Factory wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
@@ -251,6 +251,13 @@ func (e *Engine) Stop() error {
} }
log.Info("Network monitor: stopped") log.Info("Network monitor: stopped")
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
}
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
@@ -319,7 +326,7 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface.(*iface.WGIface), e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
@@ -914,7 +921,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
wgConfig := peer.WgConfig{ wgConfig := peer.WgConfig{
RemoteKey: pubKey, RemoteKey: pubKey,
WgListenPort: e.config.WgPort, WgListenPort: e.config.WgPort,
WgInterface: e.wgInterface, WgInterface: e.wgInterface.(*iface.WGIface),
AllowedIps: allowedIPs, AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey, PreSharedKey: e.config.PreSharedKey,
} }
@@ -1116,18 +1123,12 @@ func (e *Engine) close() {
} }
} }
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil { if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
} }
e.wgInterface = nil
} }
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
@@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() {
} }
// Set a new timer to debounce rapid network changes // Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() { debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period // This function is called after the debounce period
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
} }
func (e *Engine) stopDNSServer() { func (e *Engine) stopDNSServer() {
if e.dnsServer == nil {
return
}
e.dnsServer.Stop()
e.dnsServer = nil
err := fmt.Errorf("DNS server stopped") err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates() nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates { for i := range nsGroupStates {
@@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() {
nsGroupStates[i].Error = err nsGroupStates[i].Error = err
} }
e.statusRecorder.UpdateDNSStates(nsGroupStates) e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
} }
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.

View File

@@ -242,13 +242,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
nil) nil)
wgIface := &iface.MockWGIface{ wgIface := &MockWGIface{
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface.(*iface.WGIface), engine.statusRecorder, relayMgr, nil)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }

View File

@@ -1,4 +1,4 @@
package iface package internal
import ( import (
"net" "net"

View File

@@ -1,6 +1,6 @@
//go:build !windows //go:build !windows
package iface package internal
import ( import (
"net" "net"

View File

@@ -1,4 +1,4 @@
package iface package internal
import ( import (
"net" "net"

View File

@@ -32,12 +32,14 @@ const (
connPriorityRelay ConnPriority = 1 connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1 connPriorityICETurn ConnPriority = 1
connPriorityICEP2P ConnPriority = 2 connPriorityICEP2P ConnPriority = 2
reconnectMaxElapsedTime = 30 * time.Minute
) )
type WgConfig struct { type WgConfig struct {
WgListenPort int WgListenPort int
RemoteKey string RemoteKey string
WgInterface iface.IWGIface WgInterface *iface.WGIface
AllowedIps string AllowedIps string
PreSharedKey *wgtypes.Key PreSharedKey *wgtypes.Key
} }
@@ -80,9 +82,8 @@ type Conn struct {
config ConnConfig config ConnConfig
statusRecorder *Status statusRecorder *Status
wgProxyFactory *wgproxy.Factory wgProxyFactory *wgproxy.Factory
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager relayManager *relayClient.Manager
allowedIPsIP string allowedIPsIP string
handshaker *Handshaker handshaker *Handshaker
@@ -103,11 +104,14 @@ type Conn struct {
beforeAddPeerHooks []nbnet.AddHookFunc beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc
endpointRelay *net.UDPAddr wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
// for reconnection operations // for reconnection operations
iCEDisconnected chan bool iCEDisconnected chan bool
relayDisconnected chan bool relayDisconnected chan bool
connMonitor *ConnMonitor
reconnectCh <-chan struct{}
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
@@ -123,21 +127,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
var conn = &Conn{ var conn = &Conn{
log: connLog, log: connLog,
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
config: config, config: config,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgProxyFactory: wgProxyFactory, wgProxyFactory: wgProxyFactory,
signaler: signaler, signaler: signaler,
relayManager: relayManager, iFaceDiscover: iFaceDiscover,
allowedIPsIP: allowedIPsIP.String(), relayManager: relayManager,
statusRelay: NewAtomicConnStatus(), allowedIPsIP: allowedIPsIP.String(),
statusICE: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1), iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1),
} }
conn.connMonitor, conn.reconnectCh = NewConnMonitor(
signaler,
iFaceDiscover,
config,
conn.relayDisconnected,
conn.iCEDisconnected,
)
rFns := WorkerRelayCallbacks{ rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady, OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected, OnDisconnected: conn.onWorkerRelayStateDisconnected,
@@ -200,6 +214,8 @@ func (conn *Conn) startHandshakeAndReconnect() {
conn.log.Errorf("failed to send initial offer: %v", err) conn.log.Errorf("failed to send initial offer: %v", err)
} }
go conn.connMonitor.Start(conn.ctx)
if conn.workerRelay.IsController() { if conn.workerRelay.IsController() {
conn.reconnectLoopWithRetry() conn.reconnectLoopWithRetry()
} else { } else {
@@ -240,8 +256,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err := conn.removeWgPeer(); err != nil {
if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
} }
@@ -309,12 +324,14 @@ func (conn *Conn) reconnectLoopWithRetry() {
// With it, we can decrease to send necessary offer // With it, we can decrease to send necessary offer
select { select {
case <-conn.ctx.Done(): case <-conn.ctx.Done():
return
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
} }
ticker := conn.prepareExponentTicker() ticker := conn.prepareExponentTicker()
defer ticker.Stop() defer ticker.Stop()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
for { for {
select { select {
case t := <-ticker.C: case t := <-ticker.C:
@@ -342,20 +359,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
if err != nil { if err != nil {
conn.log.Errorf("failed to do handshake: %v", err) conn.log.Errorf("failed to do handshake: %v", err)
} }
case changed := <-conn.relayDisconnected:
if !changed { case <-conn.reconnectCh:
continue
}
conn.log.Debugf("Relay state changed, reset reconnect timer")
ticker.Stop()
ticker = conn.prepareExponentTicker()
case changed := <-conn.iCEDisconnected:
if !changed {
continue
}
conn.log.Debugf("ICE state changed, reset reconnect timer")
ticker.Stop() ticker.Stop()
ticker = conn.prepareExponentTicker() ticker = conn.prepareExponentTicker()
case <-conn.ctx.Done(): case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop") conn.log.Debugf("context is done, stop reconnect loop")
return return
@@ -366,10 +374,10 @@ func (conn *Conn) reconnectLoopWithRetry() {
func (conn *Conn) prepareExponentTicker() *backoff.Ticker { func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.01, RandomizationFactor: 0.1,
Multiplier: 2, Multiplier: 2,
MaxInterval: conn.config.Timeout, MaxInterval: conn.config.Timeout,
MaxElapsedTime: 0, MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, conn.ctx) }, conn.ctx)
@@ -420,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready") conn.log.Debugf("ICE connection is ready")
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
if conn.currentConnPriority > priority { if conn.currentConnPriority > priority {
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
return return
} }
conn.log.Infof("set ICE to active connection") conn.log.Infof("set ICE to active connection")
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) var (
if err != nil { ep *net.UDPAddr
return wgProxy wgproxy.Proxy
err error
)
if iceConnInfo.RelayedOnLocal {
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return
}
ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy
} else {
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
if err != nil {
log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil)
return
}
ep = directEp
} }
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) conn.log.Errorf("Before add peer hook failed: %v", err)
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
} }
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
err = conn.configureWGEndpoint(endpointUdpAddr) if conn.wgProxyRelay != nil {
if err != nil { conn.wgProxyRelay.Pause()
if wgProxy != nil { }
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close turn connection: %v", err) if wgProxy != nil {
} wgProxy.Work()
} }
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if err = conn.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyICE = wgProxy
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
@@ -482,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Tracef("ICE connection state changed to %s", newState) conn.log.Tracef("ICE connection state changed to %s", newState)
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
// switch back to relay connection // switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { if conn.isReadyToUpgrade() {
conn.log.Debugf("ICE disconnected, set Relay to active connection") conn.log.Debugf("ICE disconnected, set Relay to active connection")
err := conn.configureWGEndpoint(conn.endpointRelay) conn.wgProxyRelay.Work()
if err != nil {
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
@@ -496,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState) conn.statusICE.Set(newState)
select { conn.notifyReconnectLoopICEDisconnected(changed)
case conn.iCEDisconnected <- changed:
default:
}
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -520,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil { if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err) conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
} }
return return
} }
conn.log.Debugf("Relay connection is ready to use") conn.log.Debugf("Relay connection has been established, setup the WireGuard")
conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy() wgProxy, err := conn.newProxy(rci.relayedConn)
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
conn.endpointRelay = endpointUdpAddr
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
if conn.currentConnPriority > connPriorityRelay { conn.wgProxyRelay = wgProxy
if conn.statusICE.Get() == StatusConnected { conn.statusRelay.Set(StatusConnected)
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return return
}
} }
conn.connIDRelay = nbnet.GenerateConnID() if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
for _, hook := range conn.beforeAddPeerHooks { conn.log.Errorf("Before add peer hook failed: %v", err)
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
} }
err = conn.configureWGEndpoint(endpointUdpAddr) wgProxy.Work()
if err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err) conn.log.Warnf("Failed to close relay connection: %v", err)
} }
conn.log.Errorf("Failed to update wg peer configuration: %v", err) conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = wgProxy
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay") conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
} }
@@ -587,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
return return
} }
log.Debugf("relay connection is disconnected") conn.log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay { if conn.currentConnPriority == connPriorityRelay {
log.Debugf("clean up WireGuard config") conn.log.Debugf("clean up WireGuard config")
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err := conn.removeWgPeer(); err != nil {
if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
} }
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn() _ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.Set(StatusDisconnected)
conn.notifyReconnectLoopRelayDisconnected(changed)
select {
case conn.relayDisconnected <- changed:
default:
}
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -617,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
Relayed: conn.isRelayed(), Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
} }
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
if err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
} }
} }
@@ -755,6 +751,16 @@ func (conn *Conn) isConnected() bool {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() { func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" { if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
@@ -775,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
} }
} }
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
if !iceConnInfo.RelayedOnLocal { conn.log.Debugf("setup proxied WireGuard connection")
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy() wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil { return nil, err
conn.log.Warnf("failed to close turn proxy connection: %v", errClose) }
} return wgProxy, nil
return nil, nil, err }
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
}
func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
}
}
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Work()
} }
return ep, wgProxy, nil
} }
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {

View File

@@ -0,0 +1,212 @@
package peer
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
signalerMonitorPeriod = 5 * time.Second
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ConnMonitor struct {
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
config ConnConfig
relayDisconnected chan bool
iCEDisconnected chan bool
reconnectCh chan struct{}
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
reconnectCh := make(chan struct{}, 1)
cm := &ConnMonitor{
signaler: signaler,
iFaceDiscover: iFaceDiscover,
config: config,
relayDisconnected: relayDisconnected,
iCEDisconnected: iCEDisconnected,
reconnectCh: reconnectCh,
}
return cm, reconnectCh
}
func (cm *ConnMonitor) Start(ctx context.Context) {
signalerReady := make(chan struct{}, 1)
go cm.monitorSignalerReady(ctx, signalerReady)
localCandidatesChanged := make(chan struct{}, 1)
go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
for {
select {
case changed := <-cm.relayDisconnected:
if !changed {
continue
}
log.Debugf("Relay state changed, triggering reconnect")
cm.triggerReconnect()
case changed := <-cm.iCEDisconnected:
if !changed {
continue
}
log.Debugf("ICE state changed, triggering reconnect")
cm.triggerReconnect()
case <-signalerReady:
log.Debugf("Signaler became ready, triggering reconnect")
cm.triggerReconnect()
case <-localCandidatesChanged:
log.Debugf("Local candidates changed, triggering reconnect")
cm.triggerReconnect()
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
if cm.signaler == nil {
return
}
ticker := time.NewTicker(signalerMonitorPeriod)
defer ticker.Stop()
lastReady := true
for {
select {
case <-ticker.C:
currentReady := cm.signaler.Ready()
if !lastReady && currentReady {
select {
case signalerReady <- struct{}{}:
default:
}
}
lastReady = currentReady
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
ufrag, pwd, err := generateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
log.Warnf("Failed to handle candidate tick: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
log.Debugf("Gathering ICE candidates")
transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return fmt.Errorf("wait for gathering: %w", ctx.Err())
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
if changed := cm.updateCandidates(candidates); changed {
select {
case localCandidatesChanged <- struct{}{}:
default:
}
}
return nil
}
func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func (cm *ConnMonitor) triggerReconnect() {
select {
case cm.reconnectCh <- struct{}{}:
default:
}
}

View File

@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) return stdnet.NewNet(ifaceBlacklist)
} }

View File

@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
} }

View File

@@ -233,41 +233,16 @@ func (w *WorkerICE) Close() {
} }
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
transportNet, err := w.newStdNet() transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
if err != nil { if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err) w.log.Errorf("failed to create pion's stdnet: %s", err)
} }
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: relaySupport,
InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
UDPMux: w.config.ICEConfig.UDPMux,
UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: w.localUfrag,
LocalPwd: w.localPwd,
}
if w.config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
w.sentExtraSrflx = false w.sentExtraSrflx = false
agent, err := ice.NewAgent(agentConfig)
agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create agent: %w", err)
} }
err = agent.OnCandidate(w.onICECandidate) err = agent.OnCandidate(w.onICECandidate)
@@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
} }
} }
func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: candidateTypes,
InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList),
UDPMux: config.ICEConfig.UDPMux,
UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: ufrag,
LocalPwd: pwd,
}
if config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
return ice.NewAgent(agentConfig)
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress() relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{

View File

@@ -43,7 +43,7 @@ type clientNetwork struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.IWGIface wgInterface *iface.WGIface
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan routesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
@@ -53,7 +53,7 @@ type clientNetwork struct {
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &clientNetwork{
@@ -378,7 +378,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
if rt.IsDynamic() { if rt.IsDynamic() {
dns := nbdns.NewServiceViaMemory(wgInterface) dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))

View File

@@ -48,7 +48,7 @@ type Route struct {
currentPeerKey string currentPeerKey string
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.IWGIface wgInterface *iface.WGIface
resolverAddr string resolverAddr string
} }
@@ -58,7 +58,7 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration, interval time.Duration,
statusRecorder *peer.Status, statusRecorder *peer.Status,
wgInterface iface.IWGIface, wgInterface *iface.WGIface,
resolverAddr string, resolverAddr string,
) *Route { ) *Route {
return &Route{ return &Route{

View File

@@ -52,7 +52,7 @@ type DefaultManager struct {
sysOps *systemops.SysOps sysOps *systemops.SysOps
statusRecorder *peer.Status statusRecorder *peer.Status
relayMgr *relayClient.Manager relayMgr *relayClient.Manager
wgInterface iface.IWGIface wgInterface *iface.WGIface
pubKey string pubKey string
notifier *notifier.Notifier notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter routeRefCounter *refcounter.RouteRefCounter
@@ -64,7 +64,7 @@ func NewManager(
ctx context.Context, ctx context.Context,
pubKey string, pubKey string,
dnsRouteInterval time.Duration, dnsRouteInterval time.Duration,
wgInterface iface.IWGIface, wgInterface *iface.WGIface,
statusRecorder *peer.Status, statusRecorder *peer.Status,
relayMgr *relayClient.Manager, relayMgr *relayClient.Manager,
initialRoutes []*route.Route, initialRoutes []*route.Route,

View File

@@ -11,6 +11,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
) )
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os") return nil, fmt.Errorf("server route not supported on this os")
} }

View File

@@ -22,11 +22,11 @@ type defaultServerRouter struct {
ctx context.Context ctx context.Context
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
firewall firewall.Manager firewall firewall.Manager
wgInterface iface.IWGIface wgInterface *iface.WGIface
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{ return &defaultServerRouter{
ctx: ctx, ctx: ctx,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),

View File

@@ -23,7 +23,7 @@ const (
) )
// Setup configures sysctl settings for RP filtering and source validation. // Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface iface.IWGIface) (map[string]int, error) { func Setup(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{} keys := map[string]int{}
var result *multierror.Error var result *multierror.Error

View File

@@ -19,7 +19,7 @@ type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct { type SysOps struct {
refCounter *ExclusionCounter refCounter *ExclusionCounter
wgInterface iface.IWGIface wgInterface *iface.WGIface
// prefixes is tracking all the current added prefixes im memory // prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update) // (this is used in iOS as all route updates require a full table update)
//nolint //nolint
@@ -30,7 +30,7 @@ type SysOps struct {
notifier *notifier.Notifier notifier *notifier.Notifier
} }
func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier, notifier: notifier,

View File

@@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values. // If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) { func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr() addr := prefix.Addr()
switch { switch {
case addr.IsLoopback(), case addr.IsLoopback(),

View File

@@ -5,7 +5,6 @@ package ebpf
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
} }
// AddTurnConn add new turn connection for the proxy // AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn) wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{ wgEndpoint := &net.UDPAddr{
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
} }
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance. // From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
@@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
return packetConn, nil return packetConn, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1") localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)

View File

@@ -4,8 +4,13 @@ package ebpf
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"sync"
log "github.com/sirupsen/logrus"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -13,20 +18,55 @@ type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
} }
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
ctxConn, cancel := context.WithCancel(ctx) addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil { if err != nil {
cancel() return fmt.Errorf("add turn conn: %w", err)
return nil, fmt.Errorf("add turn conn: %w", err)
} }
e.remoteConn = remoteConn p.remoteConn = remoteConn
e.cancel = cancel p.ctx, p.cancel = context.WithCancel(ctx)
return addr, err p.wgEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyWrapper) Pause() {
if p.remoteConn == nil {
return
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
@@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return 0, ctx.Err()
}
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}
return 0, err
}
return n, nil
}

View File

@@ -7,6 +7,9 @@ import (
// Proxy is a transfer layer between the relayed connection and the WireGuard // Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface { type Proxy interface {
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) AddTurnConn(ctx context.Context, turnConn net.Conn) error
EndpointAddr() *net.UDPAddr
Work()
Pause()
CloseConn() error CloseConn() error
} }

View File

@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn() relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn) err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }

View File

@@ -15,13 +15,17 @@ import (
// WGUserSpaceProxy proxies // WGUserSpaceProxy proxies
type WGUserSpaceProxy struct { type WGUserSpaceProxy struct {
localWGListenPort int localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
} }
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
@@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
return p return p
} }
// AddTurnConn start the proxy with the given remote conn // AddTurnConn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { // The provided Context must be non-nil. If the context expires before
p.ctx, p.cancel = context.WithCancel(ctx) // the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
p.remoteConn = remoteConn // connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
var err error
dialer := net.Dialer{} dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err) log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err return err
} }
go p.proxyToRemote() p.ctx, p.cancel = context.WithCancel(ctx)
go p.proxyToLocal() p.localConn = localConn
p.remoteConn = remoteConn
return p.localConn.LocalAddr(), err return err
}
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil {
return nil
}
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
return endpointUdpAddr
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUserSpaceProxy) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
}
// Pause pauses the proxy from receiving data from the remote peer
func (p *WGUserSpaceProxy) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
} }
// CloseConn close the localConn // CloseConn close the localConn
@@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
} }
// proxyToRemote proxies from Wireguard to the RemoteKey // proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() { func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err) log.Warnf("error in proxy to remote loop: %s", err)
@@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for p.ctx.Err() == nil { for ctx.Err() == nil {
n, err := p.localConn.Read(buf) n, err := p.localConn.Read(buf)
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
@@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
_, err = p.remoteConn.Write(buf[:n]) _, err = p.remoteConn.Write(buf[:n])
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
@@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
} }
// proxyToLocal proxies from the Remote peer to local WireGuard // proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() { // if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err) log.Warnf("error in proxy to local loop: %s", err)
@@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for p.ctx.Err() == nil { for {
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
_, err = p.localConn.Write(buf[:n]) _, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Debugf("failed to write to wg interface conn: %s", err) log.Debugf("failed to write to wg interface conn: %s", err)

View File

@@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{} loadedConfig := &server.Config{}
_, err := util.ReadJson(mgmtConfigPath, loadedConfig) _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -20,6 +20,11 @@ import (
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
@@ -36,10 +41,6 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
) )
const ( const (
@@ -76,7 +77,8 @@ type AccountManager interface {
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
@@ -478,12 +480,12 @@ func (a *Account) GetPeerNetworkMap(
} }
nm := &NetworkMap{ nm := &NetworkMap{
Peers: peersToConnect, Peers: peersToConnect,
Network: a.Network.Copy(), Network: a.Network.Copy(),
Routes: routesUpdate, Routes: routesUpdate,
DNSConfig: dnsUpdate, DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers, OfflinePeers: expiredPeers,
FirewallRules: firewallRules, FirewallRules: firewallRules,
RoutesFirewallRules: routesFirewallRules, RoutesFirewallRules: routesFirewallRules,
} }
@@ -843,55 +845,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID] return a.Peers[peerID]
} }
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns true if there are changes in the JWT group membership. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { // newly groups to create and an error if any occurred.
user, ok := a.Users[userID] func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
if !ok {
return false
}
existedGroupsByName := make(map[string]*nbgroup.Group) existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range a.Groups { for _, group := range groups {
existedGroupsByName[group.Name] = group existedGroupsByName[group.Name] = group
} }
newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
// If no groups are added or removed, we should not sync account // If no groups are added or removed, we should not sync account
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return false return false, nil, nil, nil
} }
newGroupsToCreate := make([]*nbgroup.Group, 0)
var modified bool var modified bool
for _, name := range groupsToAdd { for _, name := range groupsToAdd {
group, exists := existedGroupsByName[name] group, exists := existedGroupsByName[name]
if !exists { if !exists {
group = &nbgroup.Group{ group = &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
Name: name, AccountID: user.AccountID,
Issued: nbgroup.GroupIssuedJWT, Name: name,
Issued: nbgroup.GroupIssuedJWT,
} }
a.Groups[group.ID] = group newGroupsToCreate = append(newGroupsToCreate, group)
} }
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
newAutoGroups = append(newAutoGroups, group.ID) newUserAutoGroups = append(newUserAutoGroups, group.ID)
modified = true modified = true
} }
} }
for name, id := range jwtGroupsMap { for name, id := range jwtGroupsMap {
if !slices.Contains(groupsToRemove, name) { if !slices.Contains(groupsToRemove, name) {
newAutoGroups = append(newAutoGroups, id) newUserAutoGroups = append(newUserAutoGroups, id)
continue continue
} }
modified = true modified = true
} }
user.AutoGroups = newAutoGroups
return modified return modified, newUserAutoGroups, newGroupsToCreate, nil
} }
// UserGroupsAddToPeers adds groups to all peers of user // UserGroupsAddToPeers adds groups to all peers of user
@@ -1262,37 +1263,36 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil return nil
} }
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. // AccountExists checks if an account exists.
// If an accountID is provided, it checks if the account exists and returns it. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. return am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
}
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
// If user does have an account, it returns the user's account ID.
// If the user doesn't have an account, it creates one using the provided domain. // If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created. // Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
if accountID != "" { if userID == "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) return "", status.Errorf(status.NotFound, "no valid userID provided")
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return accountID, nil
} }
if userID != "" { accountID, err := am.Store.GetAccountIDByUserID(userID)
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) if err != nil {
if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
} if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err return "", err
}
return account.Id, nil
} }
return "", err
return account.Id, nil
} }
return accountID, nil
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
} }
func isNil(i idp.Manager) bool { func isNil(i idp.Manager) bool {
@@ -1765,7 +1765,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
return nil, err return nil, err
} }
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
} }
@@ -1796,6 +1796,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if user.AccountID != accountID {
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
}
if !user.IsServiceUser && claims.Invited { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id) err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil { if err != nil {
@@ -1803,7 +1807,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
} }
} }
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
return "", "", err return "", "", err
} }
@@ -1812,7 +1816,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
@@ -1823,69 +1827,136 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
if settings.JWTGroupsClaimName == "" { if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
return nil return nil
} }
// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
oldGroups := make([]string, len(user.AutoGroups)) unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
copy(oldGroups, user.AutoGroups) defer func() {
if unlockPeer != nil {
unlockPeer()
}
}()
// Update the account if group membership changes var addNewGroups []string
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { var removeOldGroups []string
addNewGroups := difference(user.AutoGroups, oldGroups) var hasChanges bool
removeOldGroups := difference(oldGroups, user.AutoGroups) var user *User
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if settings.GroupsPropagationEnabled { user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) if err != nil {
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) return fmt.Errorf("error getting user: %w", err)
account.Network.IncSerial()
} }
if err := am.Store.SaveAccount(ctx, account); err != nil { groups, err := am.Store.GetAccountGroups(ctx, accountID)
log.WithContext(ctx).Errorf("failed to save account: %v", err) if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}
hasChanges = changed
// skip update if no changes
if !changed {
return nil return nil
} }
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) groups, err = transaction.GetAccountGroups(ctx, accountID)
am.updateAccountPeers(ctx, account) if err != nil {
} return fmt.Errorf("error getting account groups: %w", err)
}
for _, g := range addNewGroups { groupsMap := make(map[string]*nbgroup.Group, len(groups))
if group := account.GetGroup(g); group != nil { for _, group := range groups {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, groupsMap[group.ID] = group
map[string]any{ }
"group": group.Name,
"group_id": group.ID, peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
"is_service_user": user.IsServiceUser, if err != nil {
"user_name": user.ServiceUserName}) return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
unlockPeer()
unlockPeer = nil
for _, g := range removeOldGroups { return nil
if group := account.GetGroup(g); group != nil { })
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, if err != nil {
map[string]any{ return err
"group": group.Name, }
"group_id": group.ID,
"is_service_user": user.IsServiceUser, if !hasChanges {
"user_name": user.ServiceUserName}) return nil
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
} }
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
} }
} }
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
}
}
if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
@@ -1916,7 +1987,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) if claims.AccountId != "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
}
return claims.AccountId, nil
}
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil { if err != nil {
@@ -2229,7 +2310,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
routes := make(map[route.ID]*route.Route) routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{} setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup) nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)
owner := NewOwnerUser(userID)
owner.AccountID = accountID
users[userID] = owner
dnsSettings := DNSSettings{ dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0), DisabledManagementGroups: make([]string, 0),
} }
@@ -2297,12 +2382,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
// separateGroups separates user's auto groups into non-JWT and JWT groups. // separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups, // Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs. // where the keys are the group names and the values are the group IDs.
func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
newAutoGroups := make([]string, 0) newAutoGroups := make([]string, 0)
jwtAutoGroups := make(map[string]string) // map of group name to group ID jwtAutoGroups := make(map[string]string) // map of group name to group ID
allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
for _, group := range allGroups {
allGroupsMap[group.ID] = group
}
for _, id := range autoGroups { for _, id := range autoGroups {
if group, ok := allGroups[id]; ok { if group, ok := allGroupsMap[id]; ok {
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = id jwtAutoGroups[group.Name] = id
} else { } else {
@@ -2310,5 +2400,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
} }
} }
} }
return newAutoGroups, jwtAutoGroups return newAutoGroups, jwtAutoGroups
} }

View File

@@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
initAccount := newAccountWithId(context.Background(), "", userId, domain) _ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization // as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated // that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err = manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed") require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
@@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
} }
} }
func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return return
} }
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
if err != nil { assert.NoError(t, err)
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) assert.True(t, exists, "expected to get existing account after creation using userid")
}
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") _, err = manager.GetAccountIDByUserID(context.Background(), "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user ID is empty")
} }
} }
@@ -1669,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
@@ -1684,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1696,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1742,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1770,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}, },
} }
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1790,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1802,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1850,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
@@ -1861,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
@@ -2199,8 +2194,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
} }
func TestAccount_SetJWTGroups(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
// create a new account // create a new account
account := &Account{ account := &Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@@ -2211,62 +2210,120 @@ func TestAccount_SetJWTGroups(t *testing.T) {
Groups: map[string]*group.Group{ Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
}, },
Settings: &Settings{GroupsPropagationEnabled: true}, Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*User{ Users: map[string]*User{
"user1": {Id: "user1"}, "user1": {Id: "user1", AccountID: "accountID"},
"user2": {Id: "user2"}, "user2": {Id: "user2", AccountID: "accountID"},
}, },
} }
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("empty jwt groups", func(t *testing.T) { t.Run("empty jwt groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") Raw: jwt.MapClaims{"groups": []interface{}{}},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
t.Run("jwt match existing api group", func(t *testing.T) { t.Run("jwt match existing api group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") }
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"} account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
updated := account.SetJWTGroups("user1", []string{"group1"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") }
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("add jwt group", func(t *testing.T) { t.Run("add jwt group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user1",
assert.Len(t, account.Groups, 2, "new group should be added") Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") }
assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
t.Run("existed group not update", func(t *testing.T) { t.Run("existed group not update", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group2"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Len(t, account.Groups, 2, "groups count should not be changed") Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
t.Run("add new group", func(t *testing.T) { t.Run("add new group", func(t *testing.T) {
updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user2",
assert.Len(t, account.Groups, 3, "new group should be added") Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") }
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
t.Run("remove all JWT groups", func(t *testing.T) { t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user1",
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") Raw: jwt.MapClaims{"groups": []interface{}{}},
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") }
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
}) })
} }

View File

@@ -27,7 +27,8 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -58,7 +59,7 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@@ -194,14 +195,22 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface // AccountExists mock implementation of AccountExists from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
if am.GetAccountIDByUserOrAccountIdFunc != nil { if am.AccountExistsFunc != nil {
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) return am.AccountExistsFunc(ctx, accountID)
}
return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
}
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
if am.GetAccountIDByUserIdFunc != nil {
return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
} }
return "", status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountIDByUserOrAccountID is not implemented", "method GetAccountIDByUserID is not implemented",
) )
} }
@@ -444,7 +453,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil { if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
} }

View File

@@ -693,6 +693,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
updateRemotePeers := false updateRemotePeers := false
if login.UserID != "" { if login.UserID != "" {
if peer.UserID != login.UserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user")
}
changed, err := am.handleUserPeer(ctx, peer, settings) changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err

View File

@@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
if err != nil { if err != nil {
return nil, err return nil, err
} }
conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS"))
if err != nil {
conns = runtime.NumCPU()
}
sql.SetMaxOpenConns(conns)
log.Infof("Set max open db connections to %d", conns)
if err := migrate(ctx, db); err != nil { if err := migrate(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err) return nil, fmt.Errorf("migrate: %w", err)
@@ -378,15 +385,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error Create(&usersToSave).Error
} }
// SaveGroups saves the given list of groups to the database. // SaveUser saves the given user to the database.
// It updates existing groups if a conflict occurs. func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
groupsToSave := make([]nbgroup.Group, 0, len(groups)) if result.Error != nil {
for _, group := range groups { return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
} }
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
} }
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore // DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -420,7 +438,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
} }
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store") return "", status.NewGetAccountFromStoreError(result.Error)
} }
return accountID, nil return accountID, nil
@@ -433,7 +451,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return nil, status.NewSetupKeyNotFoundError() return nil, status.NewSetupKeyNotFoundError(result.Error)
} }
if key.AccountID == "" { if key.AccountID == "" {
@@ -451,7 +469,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store") return "", status.NewGetAccountFromStoreError(result.Error)
} }
return token.ID, nil return token.ID, nil
@@ -465,7 +483,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.NewGetAccountFromStoreError(result.Error)
} }
if token.UserID == "" { if token.UserID == "" {
@@ -549,7 +567,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
} }
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.NewGetAccountFromStoreError(result.Error)
} }
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
@@ -612,7 +630,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.NewGetAccountFromStoreError(result.Error)
} }
if user.AccountID == "" { if user.AccountID == "" {
@@ -629,7 +647,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.NewGetAccountFromStoreError(result.Error)
} }
if peer.AccountID == "" { if peer.AccountID == "" {
@@ -647,7 +665,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.NewGetAccountFromStoreError(result.Error)
} }
if peer.AccountID == "" { if peer.AccountID == "" {
@@ -665,7 +683,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return "", status.Errorf(status.Internal, "issue getting account from store") return "", status.NewGetAccountFromStoreError(result.Error)
} }
return accountID, nil return accountID, nil
@@ -678,7 +696,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return "", status.Errorf(status.Internal, "issue getting account from store") return "", status.NewGetAccountFromStoreError(result.Error)
} }
return accountID, nil return accountID, nil
@@ -691,7 +709,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return "", status.NewSetupKeyNotFoundError() return "", status.NewSetupKeyNotFoundError(result.Error)
} }
if accountID == "" { if accountID == "" {
@@ -712,7 +730,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account") return nil, status.Errorf(status.NotFound, "no peers found for the account")
} }
return nil, status.Errorf(status.Internal, "issue getting IPs from store") return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
} }
// Convert the JSON strings to net.IP objects // Convert the JSON strings to net.IP objects
@@ -740,7 +758,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
return nil, status.Errorf(status.NotFound, "no peers found for the account") return nil, status.Errorf(status.NotFound, "no peers found for the account")
} }
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store") return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
} }
return labels, nil return labels, nil
@@ -753,7 +771,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
} }
return nil, status.Errorf(status.Internal, "issue getting network from store") return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
} }
return accountNetwork.Network, nil return accountNetwork.Network, nil
} }
@@ -765,7 +783,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found") return nil, status.Errorf(status.NotFound, "peer not found")
} }
return nil, status.Errorf(status.Internal, "issue getting peer from store") return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
} }
return &peer, nil return &peer, nil
@@ -777,7 +795,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found") return nil, status.Errorf(status.NotFound, "settings not found")
} }
return nil, status.Errorf(status.Internal, "issue getting settings from store") return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
} }
return accountSettings.Settings, nil return accountSettings.Settings, nil
} }
@@ -945,7 +963,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
} }
return nil, status.NewSetupKeyNotFoundError() return nil, status.NewSetupKeyNotFoundError(result.Error)
} }
return &setupKey, nil return &setupKey, nil
} }
@@ -977,7 +995,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account") return status.Errorf(status.NotFound, "group 'All' not found for account")
} }
return status.Errorf(status.Internal, "issue finding group 'All'") return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
} }
for _, existingPeerID := range group.Peers { for _, existingPeerID := range group.Peers {
@@ -989,7 +1007,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID) group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil { if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'") return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
} }
return nil return nil
@@ -1003,7 +1021,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account") return status.Errorf(status.NotFound, "group not found for account")
} }
return status.Errorf(status.Internal, "issue finding group") return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
} }
for _, existingPeerID := range group.Peers { for _, existingPeerID := range group.Peers {
@@ -1015,15 +1033,20 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId) group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil { if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group") return status.Errorf(status.Internal, "issue updating group: %s", err)
} }
return nil return nil
} }
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account") return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
} }
return nil return nil
@@ -1032,7 +1055,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count") return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
} }
return nil return nil
} }
@@ -1127,6 +1150,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil return &group, nil
} }
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)

View File

@@ -1185,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes) assert.Equal(t, 2, setupKey.UsedTimes)
} }
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,
}
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
if err != nil {
t.Fatal("failed to get group")
return err
}
t.Logf("group: %v", group)
return nil
})
assert.NoError(t, err)
}

View File

@@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error {
} }
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error { func NewSetupKeyNotFoundError(err error) error {
return Errorf(NotFound, "setup key not found") return Errorf(NotFound, "setup key not found: %s", err)
}
func NewGetAccountFromStoreError(err error) error {
return Errorf(Internal, "issue getting account from store: %s", err)
} }
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store

View File

@@ -60,6 +60,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -68,7 +69,8 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -82,6 +84,7 @@ type Store interface {
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error

View File

@@ -8,14 +8,14 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
const ( const (
@@ -1254,6 +1254,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
} }
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
for _, gid := range groupsToRemove {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
// skip removing peers from group All
if group.Name == "All" {
return
}
updatedPeers := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if _, found := userPeerIDs[pid]; !found {
updatedPeers = append(updatedPeers, pid)
}
}
group.Peers = updatedPeers
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData { for _, user := range userData {
if user.ID == userID { if user.ID == userID {

View File

@@ -813,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") acc, err := am.Store.GetAccount(context.Background(), account.Id)
assert.NoError(t, err)
acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {

View File

@@ -1,11 +1,15 @@
package util package util
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"text/template"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) {
return res, nil return res, nil
} }
// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution
func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) {
envVars := getEnvMap()
f, err := os.Open(file)
if err != nil {
return nil, err
}
defer f.Close()
bs, err := io.ReadAll(f)
if err != nil {
return nil, err
}
t, err := template.New("").Parse(string(bs))
if err != nil {
return nil, fmt.Errorf("error parsing template: %v", err)
}
var output bytes.Buffer
// Execute the template, substituting environment variables
err = t.Execute(&output, envVars)
if err != nil {
return nil, fmt.Errorf("error executing template: %v", err)
}
err = json.Unmarshal(output.Bytes(), &res)
if err != nil {
return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err)
}
return res, nil
}
// getEnvMap Convert the output of os.Environ() to a map
func getEnvMap() map[string]string {
envMap := make(map[string]string)
for _, env := range os.Environ() {
parts := strings.SplitN(env, "=", 2)
if len(parts) == 2 {
envMap[parts[0]] = parts[1]
}
}
return envMap
}
// CopyFileContents copies contents of the given src file to the dst file // CopyFileContents copies contents of the given src file to the dst file
func CopyFileContents(src, dst string) (err error) { func CopyFileContents(src, dst string) (err error) {
in, err := os.Open(src) in, err := os.Open(src)

126
util/file_suite_test.go Normal file
View File

@@ -0,0 +1,126 @@
package util_test
import (
"crypto/md5"
"encoding/hex"
"io"
"os"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/netbirdio/netbird/util"
)
var _ = Describe("Client", func() {
var (
tmpDir string
)
type TestConfig struct {
SomeMap map[string]string
SomeArray []string
SomeField int
}
BeforeEach(func() {
var err error
tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
err := os.RemoveAll(tmpDir)
Expect(err).NotTo(HaveOccurred())
})
Describe("Config", func() {
Context("in JSON format", func() {
It("should be written and read successfully", func() {
m := make(map[string]string)
m["key1"] = "value1"
m["key2"] = "value2"
arr := []string{"value1", "value2"}
written := &TestConfig{
SomeMap: m,
SomeArray: arr,
SomeField: 99,
}
err := util.WriteJson(tmpDir+"/testconfig.json", written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
})
})
})
Describe("Copying file contents", func() {
Context("from one file to another", func() {
It("should be successful", func() {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := util.WriteJson(src, []string{"1", "2", "3"})
Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents(src, dst)
Expect(err).NotTo(HaveOccurred())
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
Expect(err).NotTo(HaveOccurred())
dstFile, err := os.Open(dst)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashSrc, srcFile)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashDst, dstFile)
Expect(err).NotTo(HaveOccurred())
err = srcFile.Close()
Expect(err).NotTo(HaveOccurred())
err = dstFile.Close()
Expect(err).NotTo(HaveOccurred())
Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
})
})
})
Describe("Handle config file without full path", func() {
Context("config file handling", func() {
It("should be successful", func() {
written := &TestConfig{
SomeField: 123,
}
cfgFile := "test_cfg.json"
defer os.Remove(cfgFile)
err := util.WriteJson(cfgFile, written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(cfgFile, &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
})
})
})
})

View File

@@ -1,126 +1,198 @@
package util_test package util
import ( import (
"crypto/md5"
"encoding/hex"
"io"
"os" "os"
"reflect"
. "github.com/onsi/ginkgo" "strings"
. "github.com/onsi/gomega" "testing"
"github.com/netbirdio/netbird/util"
) )
var _ = Describe("Client", func() { func TestReadJsonWithEnvSub(t *testing.T) {
type Config struct {
var ( CertFile string `json:"CertFile"`
tmpDir string Credentials string `json:"Credentials"`
) NestedOption struct {
URL string `json:"URL"`
type TestConfig struct { } `json:"NestedOption"`
SomeMap map[string]string
SomeArray []string
SomeField int
} }
BeforeEach(func() { type testCase struct {
var err error name string
tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") envVars map[string]string
Expect(err).NotTo(HaveOccurred()) jsonTemplate string
}) expectedResult Config
expectError bool
errorContains string
}
AfterEach(func() { tests := []testCase{
err := os.RemoveAll(tmpDir) {
Expect(err).NotTo(HaveOccurred()) name: "All environment variables set",
}) envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
"URL": "https://env.testing.com",
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }}",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`,
expectedResult: Config{
CertFile: "/etc/certs/env_cert.crt",
Credentials: "env_credentials",
NestedOption: struct {
URL string `json:"URL"`
}{
URL: "https://env.testing.com",
},
},
expectError: false,
},
{
name: "Missing environment variable",
envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
// "URL" is intentionally missing
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }}",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`,
expectedResult: Config{
CertFile: "/etc/certs/env_cert.crt",
Credentials: "env_credentials",
NestedOption: struct {
URL string `json:"URL"`
}{
URL: "<no value>",
},
},
expectError: false,
},
{
name: "Invalid JSON template",
envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
"URL": "https://env.testing.com",
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`, // Note the missing closing brace in "{{ .CREDENTIALS }"
expectedResult: Config{},
expectError: true,
errorContains: "unexpected \"}\" in operand",
},
{
name: "No substitutions",
envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
"URL": "https://env.testing.com",
},
jsonTemplate: `{
"CertFile": "/etc/certs/cert.crt",
"Credentials": "admnlknflkdasdf",
"NestedOption" : {
"URL": "https://testing.com"
}
}`,
expectedResult: Config{
CertFile: "/etc/certs/cert.crt",
Credentials: "admnlknflkdasdf",
NestedOption: struct {
URL string `json:"URL"`
}{
URL: "https://testing.com",
},
},
expectError: false,
},
{
name: "Should fail when Invalid characters in variables",
envVars: map[string]string{
"CERT_FILE": `"/etc/certs/"cert".crt"`,
"CREDENTIALS": `env_credentia{ls}`,
"URL": `https://env.testing.com?param={{value}}`,
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }}",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`,
expectedResult: Config{
CertFile: `"/etc/certs/"cert".crt"`,
Credentials: `env_credentia{ls}`,
NestedOption: struct {
URL string `json:"URL"`
}{
URL: `https://env.testing.com?param={{value}}`,
},
},
expectError: true,
},
}
Describe("Config", func() { for _, tc := range tests {
Context("in JSON format", func() { tc := tc
It("should be written and read successfully", func() { t.Run(tc.name, func(t *testing.T) {
for key, value := range tc.envVars {
t.Setenv(key, value)
}
m := make(map[string]string) tempFile, err := os.CreateTemp("", "config*.json")
m["key1"] = "value1" if err != nil {
m["key2"] = "value2" t.Fatalf("Failed to create temp file: %v", err)
}
arr := []string{"value1", "value2"} defer func() {
err = os.Remove(tempFile.Name())
written := &TestConfig{ if err != nil {
SomeMap: m, t.Logf("Failed to remove temp file: %v", err)
SomeArray: arr,
SomeField: 99,
} }
}()
err := util.WriteJson(tmpDir+"/testconfig.json", written) _, err = tempFile.WriteString(tc.jsonTemplate)
Expect(err).NotTo(HaveOccurred()) if err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
err = tempFile.Close()
if err != nil {
t.Fatalf("Failed to close temp file: %v", err)
}
read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) var result Config
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
}) _, err = ReadJsonWithEnvSub(tempFile.Name(), &result)
})
})
Describe("Copying file contents", func() { if tc.expectError {
Context("from one file to another", func() { if err == nil {
It("should be successful", func() { t.Fatalf("Expected error but got none")
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := util.WriteJson(src, []string{"1", "2", "3"})
Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents(src, dst)
Expect(err).NotTo(HaveOccurred())
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
Expect(err).NotTo(HaveOccurred())
dstFile, err := os.Open(dst)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashSrc, srcFile)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashDst, dstFile)
Expect(err).NotTo(HaveOccurred())
err = srcFile.Close()
Expect(err).NotTo(HaveOccurred())
err = dstFile.Close()
Expect(err).NotTo(HaveOccurred())
Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
})
})
})
Describe("Handle config file without full path", func() {
Context("config file handling", func() {
It("should be successful", func() {
written := &TestConfig{
SomeField: 123,
} }
cfgFile := "test_cfg.json" if !strings.Contains(err.Error(), tc.errorContains) {
defer os.Remove(cfgFile) t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err)
}
err := util.WriteJson(cfgFile, written) } else {
Expect(err).NotTo(HaveOccurred()) if err != nil {
t.Fatalf("ReadJsonWithEnvSub failed: %v", err)
read, err := util.ReadJson(cfgFile, &TestConfig{}) }
Expect(err).NotTo(HaveOccurred()) if !reflect.DeepEqual(result, tc.expectedResult) {
Expect(read).NotTo(BeNil()) t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult)
}) }
}
}) })
}) }
}) }