mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 19:26:39 +00:00
Compare commits
17 Commits
v0.29.2
...
debug-0.29
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f30995707a | ||
|
|
5c8bfb7cea | ||
|
|
fc4b37f7bc | ||
|
|
6f0fd1d1b3 | ||
|
|
28cbb4b70f | ||
|
|
1104c9c048 | ||
|
|
5bc601111d | ||
|
|
b74951f29e | ||
|
|
97e10e440c | ||
|
|
6c50b0c84b | ||
|
|
730dd1733e | ||
|
|
82739e2832 | ||
|
|
fa7767e612 | ||
|
|
f1171198de | ||
|
|
9e041b7f82 | ||
|
|
b4c8cf0a67 | ||
|
|
1ef51a4ffa |
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
|
|||||||
@@ -117,6 +117,11 @@ type Config struct {
|
|||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
func ReadConfig(configPath string) (*Config, error) {
|
func ReadConfig(configPath string) (*Config, error) {
|
||||||
if configFileIsExists(configPath) {
|
if configFileIsExists(configPath) {
|
||||||
|
err := util.EnforcePermission(configPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
config := &Config{}
|
config := &Config{}
|
||||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = WriteOutConfig(input.ConfigPath, cfg)
|
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
||||||
return cfg, err
|
return cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||||
input.PreSharedKey = nil
|
input.PreSharedKey = nil
|
||||||
}
|
}
|
||||||
|
err := util.EnforcePermission(input.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||||
|
}
|
||||||
return update(input)
|
return update(input)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ func (c *ConnectClient) run(
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
|
runningChanOpen := true
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
if c.isContextCancelled() {
|
if c.isContextCancelled() {
|
||||||
@@ -267,6 +268,12 @@ func (c *ConnectClient) run(
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
|
if c.engine != nil && c.engine.ctx.Err() != nil {
|
||||||
|
log.Info("Stopping Netbird Engine")
|
||||||
|
if err := c.engine.Stop(); err != nil {
|
||||||
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||||
|
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
@@ -279,9 +286,10 @@ func (c *ConnectClient) run(
|
|||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
if runningChan != nil {
|
if runningChan != nil && runningChanOpen {
|
||||||
runningChan <- nil
|
runningChan <- nil
|
||||||
close(runningChan)
|
close(runningChan)
|
||||||
|
runningChanOpen = false
|
||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|||||||
@@ -1115,10 +1115,7 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||||
if e.dnsServer != nil {
|
e.stopDNSServer()
|
||||||
e.dnsServer.Stop()
|
|
||||||
e.dnsServer = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
e.routeManager.Stop()
|
e.routeManager.Stop()
|
||||||
@@ -1360,12 +1357,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) restartEngine() {
|
func (e *Engine) restartEngine() {
|
||||||
|
log.Info("restarting engine")
|
||||||
|
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||||
|
|
||||||
if err := e.Stop(); err != nil {
|
if err := e.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
if err := e.Start(); err != nil {
|
|
||||||
log.Errorf("Failed to start engine: %v", err)
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
}
|
log.Infof("cancelling client, engine will be recreated")
|
||||||
|
e.clientCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
@@ -1387,6 +1388,7 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
if debounceTimer != nil {
|
if debounceTimer != nil {
|
||||||
|
log.Infof("Network monitor: detected network change, reset debounceTimer")
|
||||||
debounceTimer.Stop()
|
debounceTimer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1396,7 +1398,7 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
log.Infof("Network monitor detected network change, restarting engine")
|
log.Infof("Network monitor: detected network change, restarting engine")
|
||||||
e.restartEngine()
|
e.restartEngine()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -1421,6 +1423,20 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
|||||||
return false, netip.Prefix{}, nil
|
return false, netip.Prefix{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) stopDNSServer() {
|
||||||
|
err := fmt.Errorf("DNS server stopped")
|
||||||
|
nsGroupStates := e.statusRecorder.GetDNSStates()
|
||||||
|
for i := range nsGroupStates {
|
||||||
|
nsGroupStates[i].Enabled = false
|
||||||
|
nsGroupStates[i].Error = err
|
||||||
|
}
|
||||||
|
e.statusRecorder.UpdateDNSStates(nsGroupStates)
|
||||||
|
if e.dnsServer != nil {
|
||||||
|
e.dnsServer.Stop()
|
||||||
|
e.dnsServer = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||||
|
|||||||
@@ -89,8 +89,8 @@ type Conn struct {
|
|||||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||||
onDisconnected func(remotePeer string, wgIP string)
|
onDisconnected func(remotePeer string, wgIP string)
|
||||||
|
|
||||||
statusRelay ConnStatus
|
statusRelay *AtomicConnStatus
|
||||||
statusICE ConnStatus
|
statusICE *AtomicConnStatus
|
||||||
currentConnPriority ConnPriority
|
currentConnPriority ConnPriority
|
||||||
opened bool // this flag is used to prevent close in case of not opened connection
|
opened bool // this flag is used to prevent close in case of not opened connection
|
||||||
|
|
||||||
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
allowedIPsIP: allowedIPsIP.String(),
|
allowedIPsIP: allowedIPsIP.String(),
|
||||||
statusRelay: StatusDisconnected,
|
statusRelay: NewAtomicConnStatus(),
|
||||||
statusICE: StatusDisconnected,
|
statusICE: NewAtomicConnStatus(),
|
||||||
iCEDisconnected: make(chan bool, 1),
|
iCEDisconnected: make(chan bool, 1),
|
||||||
relayDisconnected: make(chan bool, 1),
|
relayDisconnected: make(chan bool, 1),
|
||||||
}
|
}
|
||||||
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
|
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
|
||||||
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if conn.statusICE == StatusDisconnected {
|
if conn.statusICE.Get() == StatusDisconnected {
|
||||||
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
|
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
conn.log.Debugf("ICE connection is ready")
|
||||||
|
|
||||||
conn.statusICE = StatusConnected
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
|
||||||
defer conn.updateIceState(iceConnInfo)
|
defer conn.updateIceState(iceConnInfo)
|
||||||
|
|
||||||
@@ -484,16 +484,16 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
||||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
|
||||||
err := conn.configureWGEndpoint(conn.endpointRelay)
|
err := conn.configureWGEndpoint(conn.endpointRelay)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE != newState && newState != StatusConnecting
|
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||||
conn.statusICE = newState
|
conn.statusICE.Set(newState)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case conn.iCEDisconnected <- changed:
|
case conn.iCEDisconnected <- changed:
|
||||||
@@ -518,11 +518,14 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
if conn.ctx.Err() != nil {
|
if conn.ctx.Err() != nil {
|
||||||
|
if err := rci.relayedConn.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("Relay connection is ready to use")
|
conn.log.Debugf("Relay connection is ready to use")
|
||||||
conn.statusRelay = StatusConnected
|
conn.statusRelay.Set(StatusConnected)
|
||||||
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||||
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
|
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
|
||||||
@@ -530,6 +533,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||||
conn.endpointRelay = endpointUdpAddr
|
conn.endpointRelay = endpointUdpAddr
|
||||||
@@ -538,7 +542,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
|
|
||||||
if conn.currentConnPriority > connPriorityRelay {
|
if conn.currentConnPriority > connPriorityRelay {
|
||||||
if conn.statusICE == StatusConnected {
|
if conn.statusICE.Get() == StatusConnected {
|
||||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -551,7 +555,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
|
||||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
err = conn.configureWGEndpoint(endpointUdpAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
@@ -560,6 +563,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
@@ -594,8 +598,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
conn.wgProxyRelay = nil
|
conn.wgProxyRelay = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay != StatusDisconnected
|
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||||
conn.statusRelay = StatusDisconnected
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case conn.relayDisconnected <- changed:
|
case conn.relayDisconnected <- changed:
|
||||||
@@ -661,8 +665,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) setStatusToDisconnected() {
|
func (conn *Conn) setStatusToDisconnected() {
|
||||||
conn.statusRelay = StatusDisconnected
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
conn.statusICE = StatusDisconnected
|
conn.statusICE.Set(StatusDisconnected)
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@@ -706,7 +710,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) isRelayed() bool {
|
func (conn *Conn) isRelayed() bool {
|
||||||
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
|
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -718,11 +722,11 @@ func (conn *Conn) isRelayed() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) evalStatus() ConnStatus {
|
func (conn *Conn) evalStatus() ConnStatus {
|
||||||
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
|
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
|
||||||
return StatusConnected
|
return StatusConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
|
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
|
||||||
return StatusConnecting
|
return StatusConnecting
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -733,12 +737,12 @@ func (conn *Conn) isConnected() bool {
|
|||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
|
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
if conn.statusRelay != StatusConnected {
|
if conn.statusRelay.Get() != StatusConnected {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -775,9 +779,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
|
|||||||
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
|
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
err = wgProxy.CloseConn()
|
if errClose := wgProxy.CloseConn(); errClose != nil {
|
||||||
if err != nil {
|
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
||||||
conn.log.Warnf("failed to close turn proxy connection: %v", err)
|
|
||||||
}
|
}
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import log "github.com/sirupsen/logrus"
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// StatusConnected indicate the peer is in connected state
|
// StatusConnected indicate the peer is in connected state
|
||||||
@@ -12,7 +16,34 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ConnStatus describe the status of a peer's connection
|
// ConnStatus describe the status of a peer's connection
|
||||||
type ConnStatus int
|
type ConnStatus int32
|
||||||
|
|
||||||
|
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
|
||||||
|
type AtomicConnStatus struct {
|
||||||
|
status atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
|
||||||
|
func NewAtomicConnStatus() *AtomicConnStatus {
|
||||||
|
acs := &AtomicConnStatus{}
|
||||||
|
acs.Set(StatusDisconnected)
|
||||||
|
return acs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the current connection status
|
||||||
|
func (acs *AtomicConnStatus) Get() ConnStatus {
|
||||||
|
return ConnStatus(acs.status.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set updates the connection status
|
||||||
|
func (acs *AtomicConnStatus) Set(status ConnStatus) {
|
||||||
|
acs.status.Store(int32(status))
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the current status
|
||||||
|
func (acs *AtomicConnStatus) String() string {
|
||||||
|
return acs.Get().String()
|
||||||
|
}
|
||||||
|
|
||||||
func (s ConnStatus) String() string {
|
func (s ConnStatus) String() string {
|
||||||
switch s {
|
switch s {
|
||||||
|
|||||||
@@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
|
|||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
t.Run(table.name, func(t *testing.T) {
|
t.Run(table.name, func(t *testing.T) {
|
||||||
conn.statusICE = table.statusIce
|
si := NewAtomicConnStatus()
|
||||||
conn.statusRelay = table.statusRelay
|
si.Set(table.statusIce)
|
||||||
|
conn.statusICE = si
|
||||||
|
|
||||||
|
sr := NewAtomicConnStatus()
|
||||||
|
sr.Set(table.statusRelay)
|
||||||
|
conn.statusRelay = sr
|
||||||
|
|
||||||
got := conn.Status()
|
got := conn.Status()
|
||||||
assert.Equal(t, got, table.want, "they should be equal")
|
assert.Equal(t, got, table.want, "they should be equal")
|
||||||
|
|||||||
@@ -597,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetRosenpassState() RosenpassState {
|
func (d *Status) GetRosenpassState() RosenpassState {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
return RosenpassState{
|
return RosenpassState{
|
||||||
d.rosenpassEnabled,
|
d.rosenpassEnabled,
|
||||||
d.rosenpassPermissive,
|
d.rosenpassPermissive,
|
||||||
@@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetManagementState() ManagementState {
|
func (d *Status) GetManagementState() ManagementState {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
return ManagementState{
|
return ManagementState{
|
||||||
d.mgmAddress,
|
d.mgmAddress,
|
||||||
d.managementState,
|
d.managementState,
|
||||||
@@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetSignalState() SignalState {
|
func (d *Status) GetSignalState() SignalState {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
return SignalState{
|
return SignalState{
|
||||||
d.signalAddress,
|
d.signalAddress,
|
||||||
d.signalState,
|
d.signalState,
|
||||||
@@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState {
|
|||||||
|
|
||||||
// GetRelayStates returns the stun/turn/permanent relay states
|
// GetRelayStates returns the stun/turn/permanent relay states
|
||||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
if d.relayMgr == nil {
|
if d.relayMgr == nil {
|
||||||
return d.relayStates
|
return d.relayStates
|
||||||
}
|
}
|
||||||
@@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetDNSStates() []NSGroupState {
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
return d.nsGroupStates
|
return d.nsGroupStates
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
|||||||
|
|
||||||
// GetFullStatus gets full status
|
// GetFullStatus gets full status
|
||||||
func (d *Status) GetFullStatus() FullStatus {
|
func (d *Status) GetFullStatus() FullStatus {
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
fullStatus := FullStatus{
|
fullStatus := FullStatus{
|
||||||
ManagementState: d.GetManagementState(),
|
ManagementState: d.GetManagementState(),
|
||||||
SignalState: d.GetSignalState(),
|
SignalState: d.GetSignalState(),
|
||||||
LocalPeerState: d.localPeer,
|
|
||||||
Relays: d.GetRelayStates(),
|
Relays: d.GetRelayStates(),
|
||||||
RosenpassState: d.GetRosenpassState(),
|
RosenpassState: d.GetRosenpassState(),
|
||||||
NSGroupStates: d.GetDNSStates(),
|
NSGroupStates: d.GetDNSStates(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
fullStatus.LocalPeerState = d.localPeer
|
||||||
|
|
||||||
for _, status := range d.peers {
|
for _, status := range d.peers {
|
||||||
fullStatus.Peers = append(fullStatus.Peers, status)
|
fullStatus.Peers = append(fullStatus.Peers, status)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx, ctxCancel := context.WithCancel(ctx)
|
ctx, ctxCancel := context.WithCancel(ctx)
|
||||||
w.wgStateCheck(ctx)
|
|
||||||
w.ctxWgWatch = ctx
|
w.ctxWgWatch = ctx
|
||||||
w.ctxCancelWgWatch = ctxCancel
|
w.ctxCancelWgWatch = ctxCancel
|
||||||
|
|
||||||
|
w.wgStateCheck(ctx, ctxCancel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) DisableWgWatcher() {
|
func (w *WorkerRelay) DisableWgWatcher() {
|
||||||
@@ -158,21 +158,22 @@ func (w *WorkerRelay) CloseConn() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||||
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
|
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
|
||||||
|
w.log.Debugf("WireGuard watcher started")
|
||||||
lastHandshake, err := w.wgState()
|
lastHandshake, err := w.wgState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Errorf("failed to read wg stats: %v", err)
|
w.log.Warnf("failed to read wg stats: %v", err)
|
||||||
lastHandshake = time.Time{}
|
lastHandshake = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
go func(lastHandshake time.Time) {
|
go func(lastHandshake time.Time) {
|
||||||
timer := time.NewTimer(wgHandshakeOvertime)
|
timer := time.NewTimer(wgHandshakeOvertime)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
|
defer ctxCancel()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
|
|
||||||
handshake, err := w.wgState()
|
handshake, err := w.wgState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Errorf("failed to read wg stats: %v", err)
|
w.log.Errorf("failed to read wg stats: %v", err)
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
// AddTurnConn start the proxy with the given remote conn
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
||||||
p.remoteConn = turnConn
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
@@ -54,6 +54,14 @@ func (p *WGUserSpaceProxy) CloseConn() error {
|
|||||||
if p.localConn == nil {
|
if p.localConn == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.remoteConn.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close remote conn: %s", err)
|
||||||
|
}
|
||||||
return p.localConn.Close()
|
return p.localConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,6 +73,8 @@ func (p *WGUserSpaceProxy) Free() error {
|
|||||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||||
|
defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -93,7 +103,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||||
|
defer p.cancel()
|
||||||
|
defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -103,7 +114,6 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
|
|||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
p.cancel()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Errorf("failed to read from remote conn: %s", err)
|
log.Errorf("failed to read from remote conn: %s", err)
|
||||||
|
|||||||
@@ -1,13 +1,41 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
// only if env is not set
|
||||||
|
if os.Getenv("NB_LOG_LEVEL") == "" {
|
||||||
|
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
|
||||||
|
log.Errorf("Failed setting log-level: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
|
||||||
|
log.Errorf("Failed setting log-size: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
|
||||||
|
log.Errorf("Failed setting panic log path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go startPprofServer()
|
||||||
|
|
||||||
if err := cmd.Execute(); err != nil {
|
if err := cmd.Execute(); err != nil {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func startPprofServer() {
|
||||||
|
pprofAddr := "localhost:6969"
|
||||||
|
log.Infof("Starting pprof debugging server on %s", pprofAddr)
|
||||||
|
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||||
|
log.Infof("pprof server failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -110,6 +110,35 @@ func (s *Server) Start() error {
|
|||||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||||
s.actCancel = cancel
|
s.actCancel = cancel
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
|
||||||
|
log.Infof("Error getting status: %v", err)
|
||||||
|
} else if statusResp.FullStatus != nil {
|
||||||
|
log.Infof("Status --------")
|
||||||
|
for _, peer := range statusResp.FullStatus.Peers {
|
||||||
|
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
|
||||||
|
peer.Fqdn,
|
||||||
|
peer.IP,
|
||||||
|
peer.PubKey,
|
||||||
|
peer.ConnStatus,
|
||||||
|
peer.Relayed,
|
||||||
|
peer.RelayAddress,
|
||||||
|
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
|
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
|
||||||
// on failure we return error to retry
|
// on failure we return error to retry
|
||||||
config, err := internal.UpdateConfig(s.latestConfigInput)
|
config, err := internal.UpdateConfig(s.latestConfigInput)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func Execute() error {
|
|||||||
func init() {
|
func init() {
|
||||||
stopCh = make(chan int)
|
stopCh = make(chan int)
|
||||||
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||||
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
||||||
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||||
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||||
|
|||||||
@@ -263,6 +263,11 @@ type AccountSettings struct {
|
|||||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Subclass used in gorm to only load network and not whole account
|
||||||
|
type AccountNetwork struct {
|
||||||
|
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
|
}
|
||||||
|
|
||||||
type UserPermissions struct {
|
type UserPermissions struct {
|
||||||
DashboardView string `json:"dashboard_view"`
|
DashboardView string `json:"dashboard_view"`
|
||||||
}
|
}
|
||||||
@@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
|||||||
return grps
|
return grps
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) getUserGroups(userID string) ([]string, error) {
|
|
||||||
user, err := a.FindUser(userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return user.AutoGroups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||||
peerGroups := a.getPeerGroups(peerID)
|
peerGroups := a.getPeerGroups(peerID)
|
||||||
enabled := true
|
enabled := true
|
||||||
@@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
|
|||||||
return groupList
|
return groupList
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
|
|
||||||
key, err := a.FindSetupKey(setupKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return key.AutoGroups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) getTakenIPs() []net.IP {
|
func (a *Account) getTakenIPs() []net.IP {
|
||||||
var takenIps []net.IP
|
var takenIps []net.IP
|
||||||
for _, existingPeer := range a.Peers {
|
for _, existingPeer := range a.Peers {
|
||||||
@@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
|
||||||
|
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
labelMap := ConvertSliceToMap(existingLabels)
|
||||||
|
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newLabel == "" {
|
||||||
|
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newLabel, nil
|
||||||
|
}
|
||||||
|
|
||||||
// addAllGroup to account object if it doesn't exist
|
// addAllGroup to account object if it doesn't exist
|
||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
||||||
@@ -115,12 +114,23 @@ func pkcs5Padding(ciphertext []byte) []byte {
|
|||||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||||
return append(ciphertext, padText...)
|
return append(ciphertext, padText...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func pkcs5UnPadding(src []byte) ([]byte, error) {
|
func pkcs5UnPadding(src []byte) ([]byte, error) {
|
||||||
srcLen := len(src)
|
srcLen := len(src)
|
||||||
paddingLen := int(src[srcLen-1])
|
if srcLen == 0 {
|
||||||
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
|
return nil, errors.New("input data is empty")
|
||||||
return nil, fmt.Errorf("padding size error")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
paddingLen := int(src[srcLen-1])
|
||||||
|
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
|
||||||
|
return nil, errors.New("invalid padding size")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that all padding bytes are the same
|
||||||
|
for i := 0; i < paddingLen; i++ {
|
||||||
|
if src[srcLen-1-i] != byte(paddingLen) {
|
||||||
|
return nil, errors.New("invalid padding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return src[:srcLen-paddingLen], nil
|
return src[:srcLen-paddingLen], nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
|
|||||||
t.Fatalf("incorrect decryption, the result is: %s", res)
|
t.Fatalf("incorrect decryption, the result is: %s", res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecrypt(t *testing.T) {
|
||||||
|
// Generate a key for encryption/decryption
|
||||||
|
key, err := GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the FieldEncrypt with the generated key
|
||||||
|
ec, err := NewFieldEncrypt(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create FieldEncrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cases
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty String",
|
||||||
|
input: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Short String",
|
||||||
|
input: "Hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "String with Spaces",
|
||||||
|
input: "Hello, World!",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Long String",
|
||||||
|
input: "The quick brown fox jumps over the lazy dog.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unicode Characters",
|
||||||
|
input: "こんにちは世界",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Special Characters",
|
||||||
|
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Numeric String",
|
||||||
|
input: "1234567890",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Repeated Characters",
|
||||||
|
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multi-block String",
|
||||||
|
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-ASCII and ASCII Mix",
|
||||||
|
input: "Hello 世界 123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name+" - Legacy", func(t *testing.T) {
|
||||||
|
// Legacy Encryption
|
||||||
|
encryptedLegacy := ec.LegacyEncrypt(tc.input)
|
||||||
|
if encryptedLegacy == "" {
|
||||||
|
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy Decryption
|
||||||
|
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the decrypted value matches the original input
|
||||||
|
if decryptedLegacy != tc.input {
|
||||||
|
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run(tc.name+" - New", func(t *testing.T) {
|
||||||
|
// New Encryption
|
||||||
|
encryptedNew, err := ec.Encrypt(tc.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
|
||||||
|
}
|
||||||
|
if encryptedNew == "" {
|
||||||
|
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New Decryption
|
||||||
|
decryptedNew, err := ec.Decrypt(encryptedNew)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the decrypted value matches the original input
|
||||||
|
if decryptedNew != tc.input {
|
||||||
|
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPKCS5UnPadding(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
expected []byte
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid Padding",
|
||||||
|
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
|
||||||
|
expected: []byte("Hello, World!"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty Input",
|
||||||
|
input: []byte{},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Padding Length Zero",
|
||||||
|
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Padding Length Exceeds Block Size",
|
||||||
|
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Padding Length Exceeds Input Length",
|
||||||
|
input: []byte{5, 5, 5},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Padding Bytes",
|
||||||
|
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Single Byte Padding",
|
||||||
|
input: append([]byte("Hello, World!"), byte(1)),
|
||||||
|
expected: []byte("Hello, World!"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Mixed Padding Bytes",
|
||||||
|
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Full Block Padding",
|
||||||
|
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
|
||||||
|
expected: []byte("Hello, World!"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-Padding Byte at End",
|
||||||
|
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Padding with Different Text Length",
|
||||||
|
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
|
||||||
|
expected: []byte("Test"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Padding Length Equal to Input Length",
|
||||||
|
input: bytes.Repeat([]byte{8}, 8),
|
||||||
|
expected: []byte{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Padding Length Zero (Again)",
|
||||||
|
input: append([]byte("Test"), byte(0)),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Padding Length Greater Than Input",
|
||||||
|
input: []byte{10},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Input Length Not Multiple of Block Size",
|
||||||
|
input: append([]byte("Invalid Length"), byte(1)),
|
||||||
|
expected: []byte("Invalid Length"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Padding with Non-ASCII Characters",
|
||||||
|
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
|
||||||
|
expected: []byte("こんにちは"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := pkcs5UnPadding(tt.input)
|
||||||
|
if tt.expectError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error but got nil")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Did not expect error but got: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(result, tt.expected) {
|
||||||
|
t.Errorf("Expected output %v, got %v", tt.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MockStore struct {
|
type MockStore struct {
|
||||||
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
|
|||||||
return s.account, nil
|
return s.account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("account not found")
|
return nil, status.NewPeerNotFoundError(peerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MocAccountManager struct {
|
type MocAccountManager struct {
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -46,6 +48,158 @@ type FileStore struct {
|
|||||||
metrics telemetry.AppMetrics `json:"-"`
|
metrics telemetry.AppMetrics `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
|
||||||
|
return f(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
|
||||||
|
if !ok {
|
||||||
|
return status.NewSetupKeyNotFoundError()
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.SetupKeys[setupKeyID].UsedTimes++
|
||||||
|
|
||||||
|
return s.SaveAccount(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil || allGroup == nil {
|
||||||
|
return errors.New("all group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
allGroup.Peers = append(allGroup.Peers, peerID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, ok := s.Accounts[peer.AccountID]
|
||||||
|
if !ok {
|
||||||
|
return status.NewAccountNotFoundError(peer.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Peers[peer.ID] = peer
|
||||||
|
return s.SaveAccount(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, ok := s.Accounts[accountId]
|
||||||
|
if !ok {
|
||||||
|
return status.NewAccountNotFoundError(accountId)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Network.Serial++
|
||||||
|
|
||||||
|
return s.SaveAccount(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, ok := account.SetupKeys[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return setupKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var takenIps []net.IP
|
||||||
|
for _, existingPeer := range account.Peers {
|
||||||
|
takenIps = append(takenIps, existingPeer.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
return takenIps, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
existingLabels := []string{}
|
||||||
|
for _, peer := range account.Peers {
|
||||||
|
if peer.DNSLabel != "" {
|
||||||
|
existingLabels = append(existingLabels, peer.DNSLabel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return existingLabels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Network, nil
|
||||||
|
}
|
||||||
|
|
||||||
type StoredAccount struct{}
|
type StoredAccount struct{}
|
||||||
|
|
||||||
// NewFileStore restores a store from the file located in the datadir
|
// NewFileStore restores a store from the file located in the datadir
|
||||||
@@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
|
|||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
@@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
|
|||||||
return account.Users[userID].Copy(), nil
|
return account.Users[userID].Copy(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
|
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
|
||||||
accountID, ok := s.UserID2AccountID[userID]
|
accountID, ok := s.UserID2AccountID[userID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
||||||
@@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
|||||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||||
account, ok := s.Accounts[accountID]
|
account, ok := s.Accounts[accountID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found")
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
@@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
|
|||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
return "", status.NewSetupKeyNotFoundError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return accountID, nil
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
|
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
@@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
|
|||||||
return nil, status.NewPeerNotFoundError(peerKey)
|
return nil, status.NewPeerNotFoundError(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
|
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
@@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
|
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
|
||||||
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
|
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_LoginPerformance(t *testing.T) {
|
func Test_LoginPerformance(t *testing.T) {
|
||||||
if os.Getenv("CI") == "true" {
|
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
|
||||||
t.Skip("Skipping on CI")
|
t.Skip("Skipping test on CI or Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||||
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
// {"M", 250, 1},
|
// {"M", 250, 1},
|
||||||
// {"L", 500, 1},
|
// {"L", 500, 1},
|
||||||
// {"XL", 750, 1},
|
// {"XL", 750, 1},
|
||||||
{"XXL", 2000, 1},
|
{"XXL", 5000, 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
@@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer mgmtServer.GracefulStop()
|
defer mgmtServer.GracefulStop()
|
||||||
|
|
||||||
|
t.Logf("management setup complete, start registering peers")
|
||||||
|
|
||||||
var counter int32
|
var counter int32
|
||||||
var counterStart int32
|
var counterStart int32
|
||||||
var wg sync.WaitGroup
|
var wgAccount sync.WaitGroup
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
messageCalls := []func() error{}
|
messageCalls := []func() error{}
|
||||||
for j := 0; j < bc.accounts; j++ {
|
for j := 0; j < bc.accounts; j++ {
|
||||||
wg.Add(1)
|
wgAccount.Add(1)
|
||||||
|
var wgPeer sync.WaitGroup
|
||||||
go func(j int, counter *int32, counterStart *int32) {
|
go func(j int, counter *int32, counterStart *int32) {
|
||||||
defer wg.Done()
|
defer wgAccount.Done()
|
||||||
|
|
||||||
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
for i := 0; i < bc.peers; i++ {
|
for i := 0; i < bc.peers; i++ {
|
||||||
|
wgPeer.Add(1)
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("failed to generate key: %v", err)
|
t.Logf("failed to generate key: %v", err)
|
||||||
@@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
mu.Lock()
|
mu.Lock()
|
||||||
messageCalls = append(messageCalls, login)
|
messageCalls = append(messageCalls, login)
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("failed to login peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.AddInt32(counterStart, 1)
|
go func(peerLogin PeerLogin, counterStart *int32) {
|
||||||
if *counterStart%100 == 0 {
|
defer wgPeer.Done()
|
||||||
t.Logf("registered %d peers", *counterStart)
|
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||||
}
|
if err != nil {
|
||||||
|
t.Logf("failed to login peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.AddInt32(counterStart, 1)
|
||||||
|
if *counterStart%100 == 0 {
|
||||||
|
t.Logf("registered %d peers", *counterStart)
|
||||||
|
}
|
||||||
|
}(peerLogin, counterStart)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
wgPeer.Wait()
|
||||||
|
|
||||||
|
t.Logf("Time for registration: %s", time.Since(startTime))
|
||||||
}(j, &counter, &counterStart)
|
}(j, &counter, &counterStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wgAccount.Wait()
|
||||||
|
|
||||||
t.Logf("prepared %d login calls", len(messageCalls))
|
t.Logf("prepared %d login calls", len(messageCalls))
|
||||||
testLoginPerformance(t, messageCalls)
|
testLoginPerformance(t, messageCalls)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
@@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var account *Account
|
|
||||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
|
||||||
account, err = am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
|
||||||
if am.idpManager != nil {
|
|
||||||
userdata, err := am.lookupUserInCache(ctx, userID, account)
|
|
||||||
if err == nil && userdata != nil {
|
|
||||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||||
// and the peer disconnects with a timeout and tries to register again.
|
// and the peer disconnects with a timeout and tries to register again.
|
||||||
// We just check if this machine has been registered before and reject the second registration.
|
// We just check if this machine has been registered before and reject the second registration.
|
||||||
// The connecting peer should be able to recover with a retry.
|
// The connecting peer should be able to recover with a retry.
|
||||||
_, err = account.FindPeerByPubKey(peer.Key)
|
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
opEvent := &activity.Event{
|
opEvent := &activity.Event{
|
||||||
Timestamp: time.Now().UTC(),
|
Timestamp: time.Now().UTC(),
|
||||||
AccountID: account.Id,
|
AccountID: accountID,
|
||||||
}
|
}
|
||||||
|
|
||||||
var ephemeral bool
|
var newPeer *nbpeer.Peer
|
||||||
setupKeyName := ""
|
|
||||||
if !addedByUser {
|
|
||||||
// validate the setup key if adding with a key
|
|
||||||
sk, err := account.FindSetupKey(upperKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !sk.IsValid() {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
var groupsToAdd []string
|
||||||
}
|
var setupKeyID string
|
||||||
|
var setupKeyName string
|
||||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
var ephemeral bool
|
||||||
opEvent.InitiatorID = sk.Id
|
if addedByUser {
|
||||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
|
||||||
ephemeral = sk.Ephemeral
|
if err != nil {
|
||||||
setupKeyName = sk.Name
|
return fmt.Errorf("failed to get user groups: %w", err)
|
||||||
} else {
|
}
|
||||||
opEvent.InitiatorID = userID
|
groupsToAdd = user.AutoGroups
|
||||||
opEvent.Activity = activity.PeerAddedByUser
|
opEvent.InitiatorID = userID
|
||||||
}
|
opEvent.Activity = activity.PeerAddedByUser
|
||||||
|
|
||||||
takenIps := account.getTakenIPs()
|
|
||||||
existingLabels := account.getPeerDNSLabels()
|
|
||||||
|
|
||||||
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.DNSLabel = newLabel
|
|
||||||
network := account.Network
|
|
||||||
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
registrationTime := time.Now().UTC()
|
|
||||||
|
|
||||||
newPeer := &nbpeer.Peer{
|
|
||||||
ID: xid.New().String(),
|
|
||||||
Key: peer.Key,
|
|
||||||
SetupKey: upperKey,
|
|
||||||
IP: nextIp,
|
|
||||||
Meta: peer.Meta,
|
|
||||||
Name: peer.Meta.Hostname,
|
|
||||||
DNSLabel: newLabel,
|
|
||||||
UserID: userID,
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
|
||||||
SSHEnabled: false,
|
|
||||||
SSHKey: peer.SSHKey,
|
|
||||||
LastLogin: registrationTime,
|
|
||||||
CreatedAt: registrationTime,
|
|
||||||
LoginExpirationEnabled: addedByUser,
|
|
||||||
Ephemeral: ephemeral,
|
|
||||||
Location: peer.Location,
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
|
||||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
|
||||||
} else {
|
} else {
|
||||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
// Validate the setup key
|
||||||
newPeer.Location.CityName = location.City.Names.En
|
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
|
||||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("failed to get setup key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// add peer to 'All' group
|
if !sk.IsValid() {
|
||||||
group, err := account.GetGroupAll()
|
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||||
if err != nil {
|
}
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
group.Peers = append(group.Peers, newPeer.ID)
|
|
||||||
|
|
||||||
var groupsToAdd []string
|
opEvent.InitiatorID = sk.Id
|
||||||
if addedByUser {
|
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||||
groupsToAdd, err = account.getUserGroups(userID)
|
groupsToAdd = sk.AutoGroups
|
||||||
if err != nil {
|
ephemeral = sk.Ephemeral
|
||||||
return nil, nil, nil, err
|
setupKeyID = sk.Id
|
||||||
|
setupKeyName = sk.Name
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(groupsToAdd) > 0 {
|
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||||
for _, s := range groupsToAdd {
|
if am.idpManager != nil {
|
||||||
if g, ok := account.Groups[s]; ok && g.Name != "All" {
|
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||||
g.Peers = append(g.Peers, newPeer.ID)
|
if err == nil && userdata != nil {
|
||||||
|
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
|
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
||||||
|
|
||||||
if addedByUser {
|
|
||||||
user, err := account.FindUser(userID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
|
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
}
|
}
|
||||||
user.updateLastLogin(newPeer.LastLogin)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[newPeer.ID] = newPeer
|
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
|
||||||
account.Network.IncSerial()
|
if err != nil {
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
return fmt.Errorf("failed to get free IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registrationTime := time.Now().UTC()
|
||||||
|
newPeer = &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: accountID,
|
||||||
|
Key: peer.Key,
|
||||||
|
SetupKey: upperKey,
|
||||||
|
IP: freeIP,
|
||||||
|
Meta: peer.Meta,
|
||||||
|
Name: peer.Meta.Hostname,
|
||||||
|
DNSLabel: freeLabel,
|
||||||
|
UserID: userID,
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||||
|
SSHEnabled: false,
|
||||||
|
SSHKey: peer.SSHKey,
|
||||||
|
LastLogin: registrationTime,
|
||||||
|
CreatedAt: registrationTime,
|
||||||
|
LoginExpirationEnabled: addedByUser,
|
||||||
|
Ephemeral: ephemeral,
|
||||||
|
Location: peer.Location,
|
||||||
|
}
|
||||||
|
opEvent.TargetID = newPeer.ID
|
||||||
|
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||||
|
if !addedByUser {
|
||||||
|
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||||
|
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||||
|
} else {
|
||||||
|
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||||
|
newPeer.Location.CityName = location.City.Names.En
|
||||||
|
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get account settings: %w", err)
|
||||||
|
}
|
||||||
|
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||||
|
|
||||||
|
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(groupsToAdd) > 0 {
|
||||||
|
for _, g := range groupsToAdd {
|
||||||
|
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.AddPeerToAccount(ctx, newPeer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if addedByUser {
|
||||||
|
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update user last login: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Account is saved, we can release the lock
|
if newPeer == nil {
|
||||||
unlock()
|
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||||
unlock = nil
|
|
||||||
|
|
||||||
opEvent.TargetID = newPeer.ID
|
|
||||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
|
||||||
if !addedByUser {
|
|
||||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||||
|
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||||
@@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
return newPeer, networkMap, postureChecks, nil
|
return newPeer, networkMap, postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
||||||
|
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextIp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||||
@@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
|
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||||
|
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||||
|
for _, label := range existingLabels {
|
||||||
|
labelMap[label] = struct{}{}
|
||||||
|
}
|
||||||
|
return labelMap
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,20 +7,24 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
nbroute "github.com/netbirdio/netbird/route"
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(response.Checks))
|
assert.Equal(t, 1, len(response.Checks))
|
||||||
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_RegisterPeerByUser(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
Key: "newPeerKey",
|
||||||
|
SetupKey: "",
|
||||||
|
IP: net.IP{123, 123, 123, 123},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
Hostname: "newPeer",
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
Name: "newPeerName",
|
||||||
|
DNSLabel: "newPeer.test",
|
||||||
|
UserID: existingUserID,
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
LastLogin: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||||
|
assert.Equal(t, peer.UserID, existingUserID)
|
||||||
|
|
||||||
|
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||||
|
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||||
|
|
||||||
|
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||||
|
|
||||||
|
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
Key: "newPeerKey",
|
||||||
|
SetupKey: "existingSetupKey",
|
||||||
|
UserID: "",
|
||||||
|
IP: net.IP{123, 123, 123, 123},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
Hostname: "newPeer",
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
Name: "newPeerName",
|
||||||
|
DNSLabel: "newPeer.test",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||||
|
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
|
||||||
|
|
||||||
|
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||||
|
|
||||||
|
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
|
||||||
|
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
|
||||||
|
|
||||||
|
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
Key: "newPeerKey",
|
||||||
|
SetupKey: "existingSetupKey",
|
||||||
|
UserID: "",
|
||||||
|
IP: net.IP{123, 123, 123, 123},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
Hostname: "newPeer",
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
Name: "newPeerName",
|
||||||
|
DNSLabel: "newPeer.test",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotContains(t, account.Peers, newPeer.ID)
|
||||||
|
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
|
||||||
|
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(0), account.Network.Serial)
|
||||||
|
|
||||||
|
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
|
||||||
|
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -33,6 +34,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
storeSqliteFileName = "store.db"
|
storeSqliteFileName = "store.db"
|
||||||
idQueryCondition = "id = ?"
|
idQueryCondition = "id = ?"
|
||||||
|
keyQueryCondition = "key = ?"
|
||||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||||
peerNotFoundFMT = "peer %s not found"
|
peerNotFoundFMT = "peer %s not found"
|
||||||
)
|
)
|
||||||
@@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||||
var key SetupKey
|
var key SetupKey
|
||||||
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if key.AccountID == "" {
|
if key.AccountID == "" {
|
||||||
@@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.First(&user, idQueryCondition, userID)
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
|
return nil, status.NewGetUserFromStoreError()
|
||||||
return nil, status.Errorf(status.Internal, "issue getting user from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
@@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
|||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found")
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
@@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
|
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
@@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
|
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
|||||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
|
|
||||||
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
|||||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
var accountID string
|
var accountID string
|
||||||
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||||
var key SetupKey
|
|
||||||
var accountID string
|
var accountID string
|
||||||
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
|
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
return "", status.NewSetupKeyNotFoundError()
|
||||||
return "", status.Errorf(status.Internal, "issue getting setup key from store")
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
return accountID, nil
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
|
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||||
|
var ipJSONStrings []string
|
||||||
|
|
||||||
|
// Fetch the IP addresses as JSON strings
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||||
|
Where("account_id = ?", accountID).
|
||||||
|
Pluck("ip", &ipJSONStrings)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the JSON strings to net.IP objects
|
||||||
|
ips := make([]net.IP, len(ipJSONStrings))
|
||||||
|
for i, ipJSON := range ipJSONStrings {
|
||||||
|
var ip net.IP
|
||||||
|
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
|
||||||
|
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
|
||||||
|
}
|
||||||
|
ips[i] = ip
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||||
|
var labels []string
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||||
|
Where("account_id = ?", accountID).
|
||||||
|
Pluck("dns_label", &labels)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return labels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||||
|
var accountNetwork AccountNetwork
|
||||||
|
|
||||||
|
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting network from store")
|
||||||
|
}
|
||||||
|
return accountNetwork.Network, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
result := s.db.First(&peer, "key = ?", peerKey)
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &peer, nil
|
return &peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
|
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||||
var accountSettings AccountSettings
|
var accountSettings AccountSettings
|
||||||
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
||||||
}
|
}
|
||||||
return accountSettings.Settings, nil
|
return accountSettings.Settings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
var user User
|
var user User
|
||||||
|
|
||||||
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
return status.NewUserNotFoundError(userID)
|
||||||
}
|
}
|
||||||
return status.Errorf(status.Internal, "issue getting user from store")
|
return status.NewGetUserFromStoreError()
|
||||||
}
|
}
|
||||||
|
|
||||||
user.LastLogin = lastLogin
|
user.LastLogin = lastLogin
|
||||||
|
|
||||||
return s.db.Save(user).Error
|
return s.db.Save(&user).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||||
@@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
|
|||||||
|
|
||||||
return store, nil
|
return store, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||||
|
var setupKey SetupKey
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
|
}
|
||||||
|
return &setupKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&SetupKey{}).
|
||||||
|
Where(idQueryCondition, setupKeyID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"used_times": gorm.Expr("used_times + 1"),
|
||||||
|
"last_used": time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||||
|
var group nbgroup.Group
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||||
|
}
|
||||||
|
return status.Errorf(status.Internal, "issue finding group 'All'")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, existingPeerID := range group.Peers {
|
||||||
|
if existingPeerID == peerID {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
group.Peers = append(group.Peers, peerID)
|
||||||
|
|
||||||
|
if err := s.db.Save(&group).Error; err != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue updating group 'All'")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||||
|
var group nbgroup.Group
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return status.Errorf(status.NotFound, "group not found for account")
|
||||||
|
}
|
||||||
|
return status.Errorf(status.Internal, "issue finding group")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, existingPeerID := range group.Peers {
|
||||||
|
if existingPeerID == peerId {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
group.Peers = append(group.Peers, peerId)
|
||||||
|
|
||||||
|
if err := s.db.Save(&group).Error; err != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue updating group")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
|
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue adding peer to account")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue incrementing network serial count")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||||
|
tx := s.db.WithContext(ctx).Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
return tx.Error
|
||||||
|
}
|
||||||
|
repo := s.withTx(tx)
|
||||||
|
err := operation(repo)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit().Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||||
|
return &SqlStore{
|
||||||
|
db: tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, id, user.PATs[id].ID)
|
require.Equal(t, id, user.PATs[id].ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []net.IP{}, takenIPs)
|
||||||
|
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ip1 := net.IP{1, 1, 1, 1}.To16()
|
||||||
|
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer2",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{2, 2, 2, 2},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ip2 := net.IP{2, 2, 2, 2}.To16()
|
||||||
|
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{}, labels)
|
||||||
|
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1.domain.test",
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer2",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer2.domain.test",
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ip := net.IP{100, 64, 0, 0}.To16()
|
||||||
|
assert.Equal(t, ip, network.Net.IP)
|
||||||
|
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
|
||||||
|
assert.Equal(t, "", network.Dns)
|
||||||
|
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
|
||||||
|
assert.Equal(t, uint64(0), network.Serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
|
||||||
|
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
|
||||||
|
assert.Equal(t, "Default key", setupKey.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, setupKey.UsedTimes)
|
||||||
|
|
||||||
|
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, setupKey.UsedTimes)
|
||||||
|
|
||||||
|
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, setupKey.UsedTimes)
|
||||||
|
}
|
||||||
|
|||||||
@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
|
|||||||
func NewPeerLoginExpiredError() error {
|
func NewPeerLoginExpiredError() error {
|
||||||
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||||
|
func NewSetupKeyNotFoundError() error {
|
||||||
|
return Errorf(NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
|
||||||
|
func NewGetUserFromStoreError() error {
|
||||||
|
return Errorf(Internal, "issue getting user from store")
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,6 +27,15 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type LockingStrength string
|
||||||
|
|
||||||
|
const (
|
||||||
|
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
|
||||||
|
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
|
||||||
|
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
|
||||||
|
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
|
||||||
|
)
|
||||||
|
|
||||||
type Store interface {
|
type Store interface {
|
||||||
GetAllAccounts(ctx context.Context) []*Account
|
GetAllAccounts(ctx context.Context) []*Account
|
||||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||||
@@ -41,7 +50,7 @@ type Store interface {
|
|||||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
@@ -60,14 +69,24 @@ type Store interface {
|
|||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||||
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
// Close should close the store persisting all unsaved data.
|
// Close should close the store persisting all unsaved data.
|
||||||
Close(ctx context.Context) error
|
Close(ctx context.Context) error
|
||||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||||
// This is also a method of metrics.DataSource interface.
|
// This is also a method of metrics.DataSource interface.
|
||||||
GetStoreEngine() StoreEngine
|
GetStoreEngine() StoreEngine
|
||||||
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||||
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||||
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||||
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
|
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||||
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
|
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||||
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||||
|
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type StoreEngine string
|
type StoreEngine string
|
||||||
|
|||||||
120
management/server/testdata/extended-store.json
vendored
Normal file
120
management/server/testdata/extended-store.json
vendored
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
{
|
||||||
|
"Accounts": {
|
||||||
|
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
|
||||||
|
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
"CreatedBy": "",
|
||||||
|
"Domain": "test.com",
|
||||||
|
"DomainCategory": "private",
|
||||||
|
"IsDomainPrimaryAccount": true,
|
||||||
|
"SetupKeys": {
|
||||||
|
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
|
||||||
|
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||||
|
"AccountID": "",
|
||||||
|
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||||
|
"Name": "Default key",
|
||||||
|
"Type": "reusable",
|
||||||
|
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||||
|
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||||
|
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||||
|
"Revoked": false,
|
||||||
|
"UsedTimes": 0,
|
||||||
|
"LastUsed": "0001-01-01T00:00:00Z",
|
||||||
|
"AutoGroups": ["cfefqs706sqkneg59g2g"],
|
||||||
|
"UsageLimit": 0,
|
||||||
|
"Ephemeral": false
|
||||||
|
},
|
||||||
|
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
|
||||||
|
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||||
|
"AccountID": "",
|
||||||
|
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||||
|
"Name": "Faulty key with non existing group",
|
||||||
|
"Type": "reusable",
|
||||||
|
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||||
|
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||||
|
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||||
|
"Revoked": false,
|
||||||
|
"UsedTimes": 0,
|
||||||
|
"LastUsed": "0001-01-01T00:00:00Z",
|
||||||
|
"AutoGroups": ["abcd"],
|
||||||
|
"UsageLimit": 0,
|
||||||
|
"Ephemeral": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Network": {
|
||||||
|
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
|
||||||
|
"Net": {
|
||||||
|
"IP": "100.64.0.0",
|
||||||
|
"Mask": "//8AAA=="
|
||||||
|
},
|
||||||
|
"Dns": "",
|
||||||
|
"Serial": 0
|
||||||
|
},
|
||||||
|
"Peers": {},
|
||||||
|
"Users": {
|
||||||
|
"edafee4e-63fb-11ec-90d6-0242ac120003": {
|
||||||
|
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||||
|
"AccountID": "",
|
||||||
|
"Role": "admin",
|
||||||
|
"IsServiceUser": false,
|
||||||
|
"ServiceUserName": "",
|
||||||
|
"AutoGroups": ["cfefqs706sqkneg59g3g"],
|
||||||
|
"PATs": {},
|
||||||
|
"Blocked": false,
|
||||||
|
"LastLogin": "0001-01-01T00:00:00Z"
|
||||||
|
},
|
||||||
|
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
|
||||||
|
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
|
||||||
|
"AccountID": "",
|
||||||
|
"Role": "user",
|
||||||
|
"IsServiceUser": false,
|
||||||
|
"ServiceUserName": "",
|
||||||
|
"AutoGroups": null,
|
||||||
|
"PATs": {
|
||||||
|
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
|
||||||
|
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||||
|
"UserID": "",
|
||||||
|
"Name": "",
|
||||||
|
"HashedToken": "SoMeHaShEdToKeN",
|
||||||
|
"ExpirationDate": "2023-02-27T00:00:00Z",
|
||||||
|
"CreatedBy": "user",
|
||||||
|
"CreatedAt": "2023-01-01T00:00:00Z",
|
||||||
|
"LastUsed": "2023-02-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Blocked": false,
|
||||||
|
"LastLogin": "0001-01-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Groups": {
|
||||||
|
"cfefqs706sqkneg59g4g": {
|
||||||
|
"ID": "cfefqs706sqkneg59g4g",
|
||||||
|
"Name": "All",
|
||||||
|
"Peers": []
|
||||||
|
},
|
||||||
|
"cfefqs706sqkneg59g3g": {
|
||||||
|
"ID": "cfefqs706sqkneg59g3g",
|
||||||
|
"Name": "AwesomeGroup1",
|
||||||
|
"Peers": []
|
||||||
|
},
|
||||||
|
"cfefqs706sqkneg59g2g": {
|
||||||
|
"ID": "cfefqs706sqkneg59g2g",
|
||||||
|
"Name": "AwesomeGroup2",
|
||||||
|
"Peers": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Rules": null,
|
||||||
|
"Policies": [],
|
||||||
|
"Routes": null,
|
||||||
|
"NameServerGroups": null,
|
||||||
|
"DNSSettings": null,
|
||||||
|
"Settings": {
|
||||||
|
"PeerLoginExpirationEnabled": false,
|
||||||
|
"PeerLoginExpiration": 86400000000000,
|
||||||
|
"GroupsPropagationEnabled": false,
|
||||||
|
"JWTGroupsEnabled": false,
|
||||||
|
"JWTGroupsClaimName": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"InstallationID": ""
|
||||||
|
}
|
||||||
@@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
|
|||||||
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) updateLastLogin(login time.Time) {
|
|
||||||
u.LastLogin = login
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
||||||
func (u *User) HasAdminPower() bool {
|
func (u *User) HasAdminPower() bool {
|
||||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||||
@@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
|||||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
|
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,37 +58,65 @@ func (m *Msg) Free() {
|
|||||||
m.bufPool.Put(m.bufPtr)
|
m.bufPool.Put(m.bufPtr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
|
||||||
|
// server and forwarding them to the upper layer content reader.
|
||||||
type connContainer struct {
|
type connContainer struct {
|
||||||
|
log *log.Entry
|
||||||
conn *Conn
|
conn *Conn
|
||||||
messages chan Msg
|
messages chan Msg
|
||||||
msgChanLock sync.Mutex
|
msgChanLock sync.Mutex
|
||||||
closed bool // flag to check if channel is closed
|
closed bool // flag to check if channel is closed
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &connContainer{
|
return &connContainer{
|
||||||
|
log: log,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
messages: messages,
|
messages: messages,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *connContainer) writeMsg(msg Msg) {
|
func (cc *connContainer) writeMsg(msg Msg) {
|
||||||
cc.msgChanLock.Lock()
|
cc.msgChanLock.Lock()
|
||||||
defer cc.msgChanLock.Unlock()
|
defer cc.msgChanLock.Unlock()
|
||||||
|
|
||||||
if cc.closed {
|
if cc.closed {
|
||||||
|
msg.Free()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cc.messages <- msg
|
|
||||||
|
select {
|
||||||
|
case cc.messages <- msg:
|
||||||
|
case <-cc.ctx.Done():
|
||||||
|
msg.Free()
|
||||||
|
default:
|
||||||
|
msg.Free()
|
||||||
|
cc.log.Infof("message queue is full")
|
||||||
|
// todo consider to close the connection
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *connContainer) close() {
|
func (cc *connContainer) close() {
|
||||||
|
cc.cancel()
|
||||||
|
|
||||||
cc.msgChanLock.Lock()
|
cc.msgChanLock.Lock()
|
||||||
defer cc.msgChanLock.Unlock()
|
defer cc.msgChanLock.Unlock()
|
||||||
|
|
||||||
if cc.closed {
|
if cc.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
close(cc.messages)
|
|
||||||
cc.closed = true
|
cc.closed = true
|
||||||
|
close(cc.messages)
|
||||||
|
|
||||||
|
for msg := range cc.messages {
|
||||||
|
msg.Free()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||||
@@ -120,8 +148,8 @@ type Client struct {
|
|||||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||||
hashedID, hashedStringId := messages.HashID(peerID)
|
hashedID, hashedStringId := messages.HashID(peerID)
|
||||||
return &Client{
|
c := &Client{
|
||||||
log: log.WithField("client_id", hashedStringId),
|
log: log.WithFields(log.Fields{"relay": serverURL}),
|
||||||
parentCtx: ctx,
|
parentCtx: ctx,
|
||||||
connectionURL: serverURL,
|
connectionURL: serverURL,
|
||||||
authTokenStore: authTokenStore,
|
authTokenStore: authTokenStore,
|
||||||
@@ -134,11 +162,13 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
|
|||||||
},
|
},
|
||||||
conns: make(map[string]*connContainer),
|
conns: make(map[string]*connContainer),
|
||||||
}
|
}
|
||||||
|
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
c.log.Infof("connecting to relay server: %s", c.connectionURL)
|
c.log.Infof("connecting to relay server")
|
||||||
c.readLoopMutex.Lock()
|
c.readLoopMutex.Lock()
|
||||||
defer c.readLoopMutex.Unlock()
|
defer c.readLoopMutex.Unlock()
|
||||||
|
|
||||||
@@ -159,7 +189,7 @@ func (c *Client) Connect() error {
|
|||||||
c.wgReadLoop.Add(1)
|
c.wgReadLoop.Add(1)
|
||||||
go c.readLoop(c.relayConn)
|
go c.readLoop(c.relayConn)
|
||||||
|
|
||||||
c.log.Infof("relay connection established with: %s", c.connectionURL)
|
c.log.Infof("relay connection established")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,11 +211,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
|||||||
return nil, ErrConnAlreadyExists
|
return nil, ErrConnAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("open connection to peer: %s", hashedStringID)
|
c.log.Infof("open connection to peer: %s", hashedStringID)
|
||||||
msgChannel := make(chan Msg, 2)
|
msgChannel := make(chan Msg, 100)
|
||||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
||||||
|
|
||||||
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,7 +259,7 @@ func (c *Client) connect() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cErr := conn.Close()
|
cErr := conn.Close()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
log.Errorf("failed to close connection: %s", cErr)
|
c.log.Errorf("failed to close connection: %s", cErr)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -240,19 +270,19 @@ func (c *Client) connect() error {
|
|||||||
func (c *Client) handShake() error {
|
func (c *Client) handShake() error {
|
||||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to marshal auth message: %s", err)
|
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.relayConn.Write(msg)
|
_, err = c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to send auth message: %s", err)
|
c.log.Errorf("failed to send auth message: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
buf := make([]byte, messages.MaxHandshakeRespSize)
|
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||||
n, err := c.readWithTimeout(buf)
|
n, err := c.readWithTimeout(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to read auth response: %s", err)
|
c.log.Errorf("failed to read auth response: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -263,12 +293,12 @@ func (c *Client) handShake() error {
|
|||||||
|
|
||||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to determine message type: %s", err)
|
c.log.Errorf("failed to determine message type: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if msgType != messages.MsgTypeAuthResponse {
|
if msgType != messages.MsgTypeAuthResponse {
|
||||||
log.Errorf("unexpected message type: %s", msgType)
|
c.log.Errorf("unexpected message type: %s", msgType)
|
||||||
return fmt.Errorf("unexpected message type")
|
return fmt.Errorf("unexpected message type")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +315,7 @@ func (c *Client) handShake() error {
|
|||||||
|
|
||||||
func (c *Client) readLoop(relayConn net.Conn) {
|
func (c *Client) readLoop(relayConn net.Conn) {
|
||||||
internallyStoppedFlag := newInternalStopFlag()
|
internallyStoppedFlag := newInternalStopFlag()
|
||||||
hc := healthcheck.NewReceiver()
|
hc := healthcheck.NewReceiver(c.log)
|
||||||
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -297,6 +327,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
buf := *bufPtr
|
buf := *bufPtr
|
||||||
n, errExit = relayConn.Read(buf)
|
n, errExit = relayConn.Read(buf)
|
||||||
if errExit != nil {
|
if errExit != nil {
|
||||||
|
c.log.Infof("start to Relay read loop exit")
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||||
@@ -343,7 +374,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
|
|||||||
case messages.MsgTypeTransport:
|
case messages.MsgTypeTransport:
|
||||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||||
case messages.MsgTypeClose:
|
case messages.MsgTypeClose:
|
||||||
log.Debugf("relay connection close by server")
|
c.log.Debugf("relay connection close by server")
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -412,14 +443,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
|||||||
// todo: use buffer pool instead of create new transport msg.
|
// todo: use buffer pool instead of create new transport msg.
|
||||||
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to marshal transport message: %s", err)
|
c.log.Errorf("failed to marshal transport message: %s", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// the write always return with 0 length because the underling does not support the size feedback.
|
// the write always return with 0 length because the underling does not support the size feedback.
|
||||||
_, err = c.relayConn.Write(msg)
|
_, err = c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to write transport message: %s", err)
|
c.log.Errorf("failed to write transport message: %s", err)
|
||||||
}
|
}
|
||||||
return len(payload), err
|
return len(payload), err
|
||||||
}
|
}
|
||||||
@@ -433,12 +464,15 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
|
|||||||
}
|
}
|
||||||
c.log.Errorf("health check timeout")
|
c.log.Errorf("health check timeout")
|
||||||
internalStopFlag.set()
|
internalStopFlag.set()
|
||||||
_ = conn.Close() // ignore the err because the readLoop will handle it
|
if err := conn.Close(); err != nil {
|
||||||
|
// ignore the err handling because the readLoop will handle it
|
||||||
|
c.log.Warnf("failed to close connection: %s", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
case <-c.parentCtx.Done():
|
case <-c.parentCtx.Done():
|
||||||
err := c.close(true)
|
err := c.close(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to teardown connection: %s", err)
|
c.log.Errorf("failed to teardown connection: %s", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -464,8 +498,9 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
|||||||
if container.conn != connReference {
|
if container.conn != connReference {
|
||||||
return fmt.Errorf("conn reference mismatch")
|
return fmt.Errorf("conn reference mismatch")
|
||||||
}
|
}
|
||||||
container.close()
|
c.log.Infof("free up connection to peer: %s", id)
|
||||||
delete(c.conns, id)
|
delete(c.conns, id)
|
||||||
|
container.close()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -478,10 +513,12 @@ func (c *Client) close(gracefullyExit bool) error {
|
|||||||
var err error
|
var err error
|
||||||
if !c.serviceIsRunning {
|
if !c.serviceIsRunning {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
c.log.Warn("relay connection was already marked as not running")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.serviceIsRunning = false
|
c.serviceIsRunning = false
|
||||||
|
c.log.Infof("closing all peer connections")
|
||||||
c.closeAllConns()
|
c.closeAllConns()
|
||||||
if gracefullyExit {
|
if gracefullyExit {
|
||||||
c.writeCloseMsg()
|
c.writeCloseMsg()
|
||||||
@@ -489,8 +526,9 @@ func (c *Client) close(gracefullyExit bool) error {
|
|||||||
err = c.relayConn.Close()
|
err = c.relayConn.Close()
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
c.log.Infof("waiting for read loop to close")
|
||||||
c.wgReadLoop.Wait()
|
c.wgReadLoop.Wait()
|
||||||
c.log.Infof("relay connection closed with: %s", c.connectionURL)
|
c.log.Infof("relay connection closed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseNotDrainedChannel(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
idAlice := "alice"
|
||||||
|
idBob := "bob"
|
||||||
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
|
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create server: %s", err)
|
||||||
|
}
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(srvCfg)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv.Shutdown(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for servers to start
|
||||||
|
if err := waitForServerToStart(errChan); err != nil {
|
||||||
|
t.Fatalf("failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||||
|
err = clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close Alice client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
||||||
|
err = clientBob.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := clientBob.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close Bob client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := "hello bob, I am alice"
|
||||||
|
// the internal channel buffer size is 2. So we should overflow it
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
_, err = connAliceToBob.Write([]byte(payload))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write to channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for delivery
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
err = connBobToAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close channel: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func waitForServerToStart(errChan chan error) error {
|
func waitForServerToStart(errChan chan error) error {
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package client
|
|||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -17,8 +16,6 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
relayCleanupInterval = 60 * time.Second
|
relayCleanupInterval = 60 * time.Second
|
||||||
connectionTimeout = 30 * time.Second
|
|
||||||
maxConcurrentServers = 7
|
|
||||||
|
|
||||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||||
)
|
)
|
||||||
@@ -92,67 +89,23 @@ func (m *Manager) Serve() error {
|
|||||||
}
|
}
|
||||||
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
||||||
|
|
||||||
totalServers := len(m.serverURLs)
|
sp := ServerPicker{
|
||||||
|
TokenStore: m.tokenStore,
|
||||||
successChan := make(chan *Client, 1)
|
PeerID: m.peerID,
|
||||||
errChan := make(chan error, len(m.serverURLs))
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
sem := make(chan struct{}, maxConcurrentServers)
|
|
||||||
|
|
||||||
for _, url := range m.serverURLs {
|
|
||||||
sem <- struct{}{}
|
|
||||||
go func(url string) {
|
|
||||||
defer func() { <-sem }()
|
|
||||||
m.connect(m.ctx, url, successChan, errChan)
|
|
||||||
}(url)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var errCount int
|
client, err := sp.PickServer(m.ctx, m.serverURLs)
|
||||||
|
if err != nil {
|
||||||
for {
|
return err
|
||||||
select {
|
|
||||||
case client := <-successChan:
|
|
||||||
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
|
|
||||||
|
|
||||||
m.relayClient = client
|
|
||||||
|
|
||||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
|
||||||
m.relayClient.SetOnDisconnectListener(func() {
|
|
||||||
m.onServerDisconnected(client.connectionURL)
|
|
||||||
})
|
|
||||||
m.startCleanupLoop()
|
|
||||||
return nil
|
|
||||||
case err := <-errChan:
|
|
||||||
errCount++
|
|
||||||
log.Warnf("Connection attempt failed: %v", err)
|
|
||||||
if errCount == totalServers {
|
|
||||||
return errors.New("failed to connect to any relay server: all attempts failed")
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
m.relayClient = client
|
||||||
|
|
||||||
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
|
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||||
// TODO: abort the connection if another connection was successful
|
m.relayClient.SetOnDisconnectListener(func() {
|
||||||
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
|
m.onServerDisconnected(client.connectionURL)
|
||||||
if err := relayClient.Connect(); err != nil {
|
})
|
||||||
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
|
m.startCleanupLoop()
|
||||||
return
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case successChan <- relayClient:
|
|
||||||
// This client was the first to connect successfully
|
|
||||||
default:
|
|
||||||
if err := relayClient.Close(); err != nil {
|
|
||||||
log.Debugf("failed to close relay client: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||||
|
|||||||
98
relay/client/picker.go
Normal file
98
relay/client/picker.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
connectionTimeout = 30 * time.Second
|
||||||
|
maxConcurrentServers = 7
|
||||||
|
)
|
||||||
|
|
||||||
|
type connResult struct {
|
||||||
|
RelayClient *Client
|
||||||
|
Url string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerPicker struct {
|
||||||
|
TokenStore *auth.TokenStore
|
||||||
|
PeerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
totalServers := len(urls)
|
||||||
|
|
||||||
|
connResultChan := make(chan connResult, totalServers)
|
||||||
|
successChan := make(chan connResult, 1)
|
||||||
|
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||||
|
|
||||||
|
for _, url := range urls {
|
||||||
|
// todo check if we have a successful connection so we do not need to connect to other servers
|
||||||
|
concurrentLimiter <- struct{}{}
|
||||||
|
go func(url string) {
|
||||||
|
defer func() {
|
||||||
|
<-concurrentLimiter
|
||||||
|
}()
|
||||||
|
sp.startConnection(parentCtx, connResultChan, url)
|
||||||
|
}(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
go sp.processConnResults(connResultChan, successChan)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case cr, ok := <-successChan:
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("failed to connect to any relay server: all attempts failed")
|
||||||
|
}
|
||||||
|
log.Infof("chosen home Relay server: %s", cr.Url)
|
||||||
|
return cr.RelayClient, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||||
|
log.Infof("try to connecting to relay server: %s", url)
|
||||||
|
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
|
||||||
|
err := relayClient.Connect()
|
||||||
|
resultChan <- connResult{
|
||||||
|
RelayClient: relayClient,
|
||||||
|
Url: url,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
|
||||||
|
var hasSuccess bool
|
||||||
|
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
|
||||||
|
cr := <-resultChan
|
||||||
|
if cr.Err != nil {
|
||||||
|
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Infof("connected to Relay server: %s", cr.Url)
|
||||||
|
|
||||||
|
if hasSuccess {
|
||||||
|
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
|
||||||
|
if err := cr.RelayClient.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hasSuccess = true
|
||||||
|
successChan <- cr
|
||||||
|
}
|
||||||
|
close(successChan)
|
||||||
|
}
|
||||||
31
relay/client/picker_test.go
Normal file
31
relay/client/picker_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||||
|
sp := ServerPicker{
|
||||||
|
TokenStore: nil,
|
||||||
|
PeerID: "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||||
|
t.Errorf("PickServer() took too long to complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,15 +23,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
metricsPort = 9090
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ListenAddress string
|
ListenAddress string
|
||||||
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
|
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
|
||||||
// it is a domain:port or ip:port
|
// it is a domain:port or ip:port
|
||||||
ExposedAddress string
|
ExposedAddress string
|
||||||
|
MetricsPort int
|
||||||
LetsencryptEmail string
|
LetsencryptEmail string
|
||||||
LetsencryptDataDir string
|
LetsencryptDataDir string
|
||||||
LetsencryptDomains []string
|
LetsencryptDomains []string
|
||||||
@@ -80,6 +77,7 @@ func init() {
|
|||||||
cobraConfig = &Config{}
|
cobraConfig = &Config{}
|
||||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
|
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
|
||||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
|
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
|
||||||
|
rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
|
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
|
||||||
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
|
||||||
@@ -116,7 +114,7 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to initialize log: %s", err)
|
return fmt.Errorf("failed to initialize log: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
metricsServer, err := metrics.NewServer(metricsPort, "")
|
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("setup metrics: %v", err)
|
log.Debugf("setup metrics: %v", err)
|
||||||
return fmt.Errorf("setup metrics: %v", err)
|
return fmt.Errorf("setup metrics: %v", err)
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package healthcheck
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
heartbeatTimeout = healthCheckInterval + 3*time.Second
|
heartbeatTimeout = healthCheckInterval + 10*time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// Receiver is a healthcheck receiver
|
// Receiver is a healthcheck receiver
|
||||||
@@ -14,23 +16,26 @@ var (
|
|||||||
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
|
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
|
||||||
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
|
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
|
||||||
type Receiver struct {
|
type Receiver struct {
|
||||||
OnTimeout chan struct{}
|
OnTimeout chan struct{}
|
||||||
|
log *log.Entry
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
heartbeat chan struct{}
|
heartbeat chan struct{}
|
||||||
alive bool
|
alive bool
|
||||||
|
attemptThreshold int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReceiver creates a new healthcheck receiver and start the timer in the background
|
// NewReceiver creates a new healthcheck receiver and start the timer in the background
|
||||||
func NewReceiver() *Receiver {
|
func NewReceiver(log *log.Entry) *Receiver {
|
||||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
r := &Receiver{
|
r := &Receiver{
|
||||||
OnTimeout: make(chan struct{}, 1),
|
OnTimeout: make(chan struct{}, 1),
|
||||||
ctx: ctx,
|
log: log,
|
||||||
ctxCancel: ctxCancel,
|
ctx: ctx,
|
||||||
heartbeat: make(chan struct{}, 1),
|
ctxCancel: ctxCancel,
|
||||||
|
heartbeat: make(chan struct{}, 1),
|
||||||
|
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||||
}
|
}
|
||||||
|
|
||||||
go r.waitForHealthcheck()
|
go r.waitForHealthcheck()
|
||||||
@@ -56,16 +61,23 @@ func (r *Receiver) waitForHealthcheck() {
|
|||||||
defer r.ctxCancel()
|
defer r.ctxCancel()
|
||||||
defer close(r.OnTimeout)
|
defer close(r.OnTimeout)
|
||||||
|
|
||||||
|
failureCounter := 0
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-r.heartbeat:
|
case <-r.heartbeat:
|
||||||
r.alive = true
|
r.alive = true
|
||||||
|
failureCounter = 0
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if r.alive {
|
if r.alive {
|
||||||
r.alive = false
|
r.alive = false
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
failureCounter++
|
||||||
|
if failureCounter < r.attemptThreshold {
|
||||||
|
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
|
||||||
|
continue
|
||||||
|
}
|
||||||
r.notifyTimeout()
|
r.notifyTimeout()
|
||||||
return
|
return
|
||||||
case <-r.ctx.Done():
|
case <-r.ctx.Done():
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
package healthcheck
|
package healthcheck
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewReceiver(t *testing.T) {
|
func TestNewReceiver(t *testing.T) {
|
||||||
heartbeatTimeout = 5 * time.Second
|
heartbeatTimeout = 5 * time.Second
|
||||||
r := NewReceiver()
|
r := NewReceiver(log.WithContext(context.Background()))
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-r.OnTimeout:
|
case <-r.OnTimeout:
|
||||||
@@ -19,7 +24,7 @@ func TestNewReceiver(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewReceiverNotReceive(t *testing.T) {
|
func TestNewReceiverNotReceive(t *testing.T) {
|
||||||
heartbeatTimeout = 1 * time.Second
|
heartbeatTimeout = 1 * time.Second
|
||||||
r := NewReceiver()
|
r := NewReceiver(log.WithContext(context.Background()))
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-r.OnTimeout:
|
case <-r.OnTimeout:
|
||||||
@@ -30,7 +35,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewReceiverAck(t *testing.T) {
|
func TestNewReceiverAck(t *testing.T) {
|
||||||
heartbeatTimeout = 2 * time.Second
|
heartbeatTimeout = 2 * time.Second
|
||||||
r := NewReceiver()
|
r := NewReceiver(log.WithContext(context.Background()))
|
||||||
|
|
||||||
r.Heartbeat()
|
r.Heartbeat()
|
||||||
|
|
||||||
@@ -40,3 +45,53 @@ func TestNewReceiverAck(t *testing.T) {
|
|||||||
case <-time.After(3 * time.Second):
|
case <-time.After(3 * time.Second):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||||
|
testsCases := []struct {
|
||||||
|
name string
|
||||||
|
threshold int
|
||||||
|
resetCounterOnce bool
|
||||||
|
}{
|
||||||
|
{"Default attempt threshold", defaultAttemptThreshold, false},
|
||||||
|
{"Custom attempt threshold", 3, false},
|
||||||
|
{"Should reset threshold once", 2, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testsCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
originalInterval := healthCheckInterval
|
||||||
|
originalTimeout := heartbeatTimeout
|
||||||
|
healthCheckInterval = 1 * time.Second
|
||||||
|
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
||||||
|
defer func() {
|
||||||
|
healthCheckInterval = originalInterval
|
||||||
|
heartbeatTimeout = originalTimeout
|
||||||
|
}()
|
||||||
|
//nolint:tenv
|
||||||
|
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||||
|
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||||
|
|
||||||
|
receiver := NewReceiver(log.WithField("test_name", tc.name))
|
||||||
|
|
||||||
|
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
|
||||||
|
|
||||||
|
if tc.resetCounterOnce {
|
||||||
|
receiver.Heartbeat()
|
||||||
|
t.Logf("reset counter once")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-receiver.OnTimeout:
|
||||||
|
if tc.resetCounterOnce {
|
||||||
|
t.Fatalf("should not have timed out before %s", testTimeout)
|
||||||
|
}
|
||||||
|
case <-time.After(testTimeout):
|
||||||
|
if tc.resetCounterOnce {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("should have timed out before %s", testTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,12 +2,21 @@ package healthcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultAttemptThreshold = 1
|
||||||
|
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
healthCheckInterval = 25 * time.Second
|
healthCheckInterval = 25 * time.Second
|
||||||
healthCheckTimeout = 5 * time.Second
|
healthCheckTimeout = 20 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sender is a healthcheck sender
|
// Sender is a healthcheck sender
|
||||||
@@ -15,20 +24,25 @@ var (
|
|||||||
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
|
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
|
||||||
// It will also stop if the context is canceled
|
// It will also stop if the context is canceled
|
||||||
type Sender struct {
|
type Sender struct {
|
||||||
|
log *log.Entry
|
||||||
// HealthCheck is a channel to send health check signal to the peer
|
// HealthCheck is a channel to send health check signal to the peer
|
||||||
HealthCheck chan struct{}
|
HealthCheck chan struct{}
|
||||||
// Timeout is a channel to the health check signal is not received in a certain time
|
// Timeout is a channel to the health check signal is not received in a certain time
|
||||||
Timeout chan struct{}
|
Timeout chan struct{}
|
||||||
|
|
||||||
ack chan struct{}
|
ack chan struct{}
|
||||||
|
alive bool
|
||||||
|
attemptThreshold int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSender creates a new healthcheck sender
|
// NewSender creates a new healthcheck sender
|
||||||
func NewSender() *Sender {
|
func NewSender(log *log.Entry) *Sender {
|
||||||
hc := &Sender{
|
hc := &Sender{
|
||||||
HealthCheck: make(chan struct{}, 1),
|
log: log,
|
||||||
Timeout: make(chan struct{}, 1),
|
HealthCheck: make(chan struct{}, 1),
|
||||||
ack: make(chan struct{}, 1),
|
Timeout: make(chan struct{}, 1),
|
||||||
|
ack: make(chan struct{}, 1),
|
||||||
|
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return hc
|
return hc
|
||||||
@@ -46,23 +60,51 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
|||||||
ticker := time.NewTicker(healthCheckInterval)
|
ticker := time.NewTicker(healthCheckInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
|
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
|
||||||
defer timeoutTimer.Stop()
|
defer timeoutTicker.Stop()
|
||||||
|
|
||||||
defer close(hc.HealthCheck)
|
defer close(hc.HealthCheck)
|
||||||
defer close(hc.Timeout)
|
defer close(hc.Timeout)
|
||||||
|
|
||||||
|
failureCounter := 0
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
hc.HealthCheck <- struct{}{}
|
hc.HealthCheck <- struct{}{}
|
||||||
case <-timeoutTimer.C:
|
case <-timeoutTicker.C:
|
||||||
|
if hc.alive {
|
||||||
|
hc.alive = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
failureCounter++
|
||||||
|
if failureCounter < hc.attemptThreshold {
|
||||||
|
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
|
||||||
|
continue
|
||||||
|
}
|
||||||
hc.Timeout <- struct{}{}
|
hc.Timeout <- struct{}{}
|
||||||
return
|
return
|
||||||
case <-hc.ack:
|
case <-hc.ack:
|
||||||
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
|
failureCounter = 0
|
||||||
|
hc.alive = true
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hc *Sender) getTimeoutTime() time.Duration {
|
||||||
|
return healthCheckInterval + healthCheckTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAttemptThresholdFromEnv() int {
|
||||||
|
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
|
||||||
|
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
|
||||||
|
return defaultAttemptThreshold
|
||||||
|
}
|
||||||
|
return int(threshold)
|
||||||
|
}
|
||||||
|
return defaultAttemptThreshold
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,9 +2,12 @@ package healthcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
|
|||||||
func TestNewHealthPeriod(t *testing.T) {
|
func TestNewHealthPeriod(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
hc := NewSender()
|
hc := NewSender(log.WithContext(ctx))
|
||||||
go hc.StartHealthCheck(ctx)
|
go hc.StartHealthCheck(ctx)
|
||||||
|
|
||||||
iterations := 0
|
iterations := 0
|
||||||
@@ -38,7 +41,7 @@ func TestNewHealthPeriod(t *testing.T) {
|
|||||||
func TestNewHealthFailed(t *testing.T) {
|
func TestNewHealthFailed(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
hc := NewSender()
|
hc := NewSender(log.WithContext(ctx))
|
||||||
go hc.StartHealthCheck(ctx)
|
go hc.StartHealthCheck(ctx)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -50,7 +53,7 @@ func TestNewHealthFailed(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewHealthcheckStop(t *testing.T) {
|
func TestNewHealthcheckStop(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
hc := NewSender()
|
hc := NewSender(log.WithContext(ctx))
|
||||||
go hc.StartHealthCheck(ctx)
|
go hc.StartHealthCheck(ctx)
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
|
|||||||
func TestTimeoutReset(t *testing.T) {
|
func TestTimeoutReset(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
hc := NewSender()
|
hc := NewSender(log.WithContext(ctx))
|
||||||
go hc.StartHealthCheck(ctx)
|
go hc.StartHealthCheck(ctx)
|
||||||
|
|
||||||
iterations := 0
|
iterations := 0
|
||||||
@@ -101,3 +104,102 @@ func TestTimeoutReset(t *testing.T) {
|
|||||||
t.Fatalf("is not exited")
|
t.Fatalf("is not exited")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||||
|
testsCases := []struct {
|
||||||
|
name string
|
||||||
|
threshold int
|
||||||
|
resetCounterOnce bool
|
||||||
|
}{
|
||||||
|
{"Default attempt threshold", defaultAttemptThreshold, false},
|
||||||
|
{"Custom attempt threshold", 3, false},
|
||||||
|
{"Should reset threshold once", 2, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testsCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
originalInterval := healthCheckInterval
|
||||||
|
originalTimeout := healthCheckTimeout
|
||||||
|
healthCheckInterval = 1 * time.Second
|
||||||
|
healthCheckTimeout = 500 * time.Millisecond
|
||||||
|
defer func() {
|
||||||
|
healthCheckInterval = originalInterval
|
||||||
|
healthCheckTimeout = originalTimeout
|
||||||
|
}()
|
||||||
|
|
||||||
|
//nolint:tenv
|
||||||
|
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||||
|
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sender := NewSender(log.WithField("test_name", tc.name))
|
||||||
|
go sender.StartHealthCheck(ctx)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
responded := false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case _, ok := <-sender.HealthCheck:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tc.resetCounterOnce && !responded {
|
||||||
|
responded = true
|
||||||
|
sender.OnHCResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-sender.Timeout:
|
||||||
|
if tc.resetCounterOnce {
|
||||||
|
t.Fatalf("should not have timed out before %s", testTimeout)
|
||||||
|
}
|
||||||
|
case <-time.After(testTimeout):
|
||||||
|
if tc.resetCounterOnce {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("should have timed out before %s", testTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:tenv
|
||||||
|
func TestGetAttemptThresholdFromEnv(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
envValue string
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
|
||||||
|
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
|
||||||
|
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.envValue == "" {
|
||||||
|
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||||
|
} else {
|
||||||
|
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := getAttemptThresholdFromEnv()
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Fatalf("Expected %d, got %d", tt.expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (p *Peer) Work() {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
hc := healthcheck.NewSender()
|
hc := healthcheck.NewSender(p.log)
|
||||||
go hc.StartHealthCheck(ctx)
|
go hc.StartHealthCheck(ctx)
|
||||||
go p.handleHealthcheckEvents(ctx, hc)
|
go p.handleHealthcheckEvents(ctx, hc)
|
||||||
|
|
||||||
@@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
|
|||||||
// connection.
|
// connection.
|
||||||
func (p *Peer) CloseGracefully(ctx context.Context) {
|
func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||||
p.connMu.Lock()
|
p.connMu.Lock()
|
||||||
|
defer p.connMu.Unlock()
|
||||||
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
|
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||||
@@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) Close() {
|
||||||
|
p.connMu.Lock()
|
||||||
defer p.connMu.Unlock()
|
defer p.connMu.Unlock()
|
||||||
|
|
||||||
|
if err := p.conn.Close(); err != nil {
|
||||||
|
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the peer ID
|
// String returns the peer ID
|
||||||
@@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||||
}
|
}
|
||||||
|
p.log.Info("peer connection closed due healthcheck timeout")
|
||||||
return
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -19,10 +19,14 @@ func NewStore() *Store {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddPeer adds a peer to the store
|
// AddPeer adds a peer to the store
|
||||||
// todo: consider to close peer conn if the peer already exists
|
|
||||||
func (s *Store) AddPeer(peer *Peer) {
|
func (s *Store) AddPeer(peer *Peer) {
|
||||||
s.peersLock.Lock()
|
s.peersLock.Lock()
|
||||||
defer s.peersLock.Unlock()
|
defer s.peersLock.Unlock()
|
||||||
|
odlPeer, ok := s.peers[peer.String()]
|
||||||
|
if ok {
|
||||||
|
odlPeer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
s.peers[peer.String()] = peer
|
s.peers[peer.String()] = peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,57 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockConn struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) Read(b []byte) (n int, err error) {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) Write(b []byte) (n int, err error) {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) LocalAddr() net.Addr {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) RemoteAddr() net.Addr {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) SetDeadline(t time.Time) error {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) SetReadDeadline(t time.Time) error {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
func TestStore_DeletePeer(t *testing.T) {
|
func TestStore_DeletePeer(t *testing.T) {
|
||||||
s := NewStore()
|
s := NewStore()
|
||||||
|
|
||||||
@@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
|||||||
|
|
||||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||||
|
|
||||||
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
|
conn := &mockConn{}
|
||||||
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
|
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||||
|
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||||
|
|
||||||
s.AddPeer(p1)
|
s.AddPeer(p1)
|
||||||
s.AddPeer(p2)
|
s.AddPeer(p2)
|
||||||
|
|||||||
@@ -29,12 +29,9 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
metricsPort = 9090
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
signalPort int
|
signalPort int
|
||||||
|
metricsPort int
|
||||||
signalLetsencryptDomain string
|
signalLetsencryptDomain string
|
||||||
signalSSLDir string
|
signalSSLDir string
|
||||||
defaultSignalSSLDir string
|
defaultSignalSSLDir string
|
||||||
@@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||||
|
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||||
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
|
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
|
||||||
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||||
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||||
|
|||||||
@@ -82,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) {
|
|||||||
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
|
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
|
||||||
peer.Id, peer.StreamID, pp.StreamID)
|
peer.Id, peer.StreamID, pp.StreamID)
|
||||||
registry.Peers.Store(peer.Id, peer)
|
registry.Peers.Store(peer.Id, peer)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("peer registered [%s]", peer.Id)
|
log.Debugf("peer registered [%s]", peer.Id)
|
||||||
|
registry.metrics.ActivePeers.Add(context.Background(), 1)
|
||||||
|
|
||||||
// record time as milliseconds
|
// record time as milliseconds
|
||||||
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
|
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
|
||||||
@@ -105,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) {
|
|||||||
peer.Id, pp.StreamID, peer.StreamID)
|
peer.Id, pp.StreamID, peer.StreamID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
registry.metrics.ActivePeers.Add(context.Background(), -1)
|
||||||
|
log.Debugf("peer deregistered [%s]", peer.Id)
|
||||||
|
registry.metrics.Deregistrations.Add(context.Background(), 1)
|
||||||
}
|
}
|
||||||
log.Debugf("peer deregistered [%s]", peer.Id)
|
|
||||||
|
|
||||||
registry.metrics.Deregistrations.Add(context.Background(), 1)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -133,8 +133,6 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
|
|||||||
s.registry.Register(p)
|
s.registry.Register(p)
|
||||||
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
|
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
|
||||||
|
|
||||||
s.metrics.ActivePeers.Add(stream.Context(), 1)
|
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
} else {
|
} else {
|
||||||
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
|
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
|
||||||
@@ -151,7 +149,6 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
|
|||||||
s.registry.Deregister(p)
|
s.registry.Deregister(p)
|
||||||
|
|
||||||
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
|
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
|
||||||
s.metrics.ActivePeers.Add(context.Background(), -1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
|
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
|
||||||
|
|||||||
87
util/file.go
87
util/file.go
@@ -10,51 +10,30 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
|
||||||
// The output JSON is pretty-formatted
|
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||||
func WriteJson(file string, obj interface{}) error {
|
|
||||||
|
|
||||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// make it pretty
|
err = EnforcePermission(file)
|
||||||
bs, err := json.MarshalIndent(obj, "", " ")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
return writeJson(file, obj, configDir, configFileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||||
|
// The output JSON is pretty-formatted
|
||||||
|
func WriteJson(file string, obj interface{}) error {
|
||||||
|
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tempFileName := tempFile.Name()
|
return writeJson(file, obj, configDir, configFileName)
|
||||||
// closing file ops as windows doesn't allow to move it
|
|
||||||
err = tempFile.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
_, err = os.Stat(tempFileName)
|
|
||||||
if err == nil {
|
|
||||||
os.Remove(tempFileName)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = os.WriteFile(tempFileName, bs, 0600)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.Rename(tempFileName, file)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
||||||
@@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
|
||||||
|
|
||||||
|
// make it pretty
|
||||||
|
bs, err := json.MarshalIndent(obj, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tempFileName := tempFile.Name()
|
||||||
|
// closing file ops as windows doesn't allow to move it
|
||||||
|
err = tempFile.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_, err = os.Stat(tempFileName)
|
||||||
|
if err == nil {
|
||||||
|
os.Remove(tempFileName)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = os.WriteFile(tempFileName, bs, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Rename(tempFileName, file)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func openOrCreateFile(file string) (*os.File, error) {
|
func openOrCreateFile(file string) (*os.File, error) {
|
||||||
s, err := os.Stat(file)
|
s, err := os.Stat(file)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err := os.MkdirAll(configDir, 0750)
|
err := os.MkdirAll(configDir, 0750)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
return configDir, configFileName, err
|
return configDir, configFileName, err
|
||||||
}
|
}
|
||||||
|
|||||||
7
util/permission.go
Normal file
7
util/permission.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package util
|
||||||
|
|
||||||
|
func EnforcePermission(dirPath string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
86
util/permission_windows.go
Normal file
86
util/permission_windows.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
securityFlags = windows.OWNER_SECURITY_INFORMATION |
|
||||||
|
windows.GROUP_SECURITY_INFORMATION |
|
||||||
|
windows.DACL_SECURITY_INFORMATION |
|
||||||
|
windows.PROTECTED_DACL_SECURITY_INFORMATION
|
||||||
|
)
|
||||||
|
|
||||||
|
func EnforcePermission(file string) error {
|
||||||
|
dirPath := filepath.Dir(file)
|
||||||
|
|
||||||
|
user, group, err := sids()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
explicitAccess := []windows.EXPLICIT_ACCESS{
|
||||||
|
{
|
||||||
|
AccessPermissions: windows.GENERIC_ALL,
|
||||||
|
AccessMode: windows.SET_ACCESS,
|
||||||
|
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||||
|
Trustee: windows.TRUSTEE{
|
||||||
|
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
|
||||||
|
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||||
|
TrusteeType: windows.TRUSTEE_IS_USER,
|
||||||
|
TrusteeValue: windows.TrusteeValueFromSID(user),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
AccessPermissions: windows.GENERIC_ALL,
|
||||||
|
AccessMode: windows.SET_ACCESS,
|
||||||
|
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||||
|
Trustee: windows.TRUSTEE{
|
||||||
|
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
|
||||||
|
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||||
|
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
|
||||||
|
TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
dacl, err := windows.ACLFromEntries(explicitAccess, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sids() (*windows.SID, *windows.SID, error) {
|
||||||
|
var token windows.Token
|
||||||
|
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := token.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close process token: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tu, err := token.GetTokenUser()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pg, err := token.GetTokenPrimaryGroup()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tu.User.Sid, pg.PrimaryGroup, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user