Compare commits

...

23 Commits

Author SHA1 Message Date
Pascal Fischer
ad8459ea2f add mysql support [WIP] 2024-09-27 13:44:50 +02:00
Zoltan Papp
4ebf6e1c4c [client] Close the remote conn in proxy (#2626)
Port the conn close call to eBPF proxy
2024-09-25 18:50:10 +02:00
pascal-fischer
1e4a0f77e2 Add get DB method to store (#2650) 2024-09-25 18:22:27 +02:00
Viktor Liu
b51d75204b [client] Anonymize relay address in status peers view (#2640) 2024-09-24 20:58:18 +02:00
Viktor Liu
e7d52c8c95 [client] Fix error count formatting (#2641) 2024-09-24 20:57:56 +02:00
Viktor Liu
ab82302c95 [client] Remove usage of custom dialer for localhost (#2639)
* Downgrade error log level for network monitor warnings

* Do not use custom dialer for localhost
2024-09-24 12:29:15 +02:00
pascal-fischer
d47be154ea [misc] Fix ip range posture check example (#2628) 2024-09-23 10:02:03 +02:00
Bethuel Mmbaga
35c892aea3 [management] Restrict accessible peers to user-owned peers for non-admins (#2618)
* Restrict accessible peers to user-owned peers for non-admin users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add service user test

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* reuse account from token

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* return error when peer not found

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-20 12:36:58 +03:00
Zoltan Papp
fc4b37f7bc Exit from processConnResults after all tries (#2621)
* Exit from processConnResults after all tries

If all server is unavailable then the server picker never return
because we never close the result channel.
Count the number of the results and exit when we reached the
expected size
2024-09-19 13:49:28 +02:00
Zoltan Papp
6f0fd1d1b3 - Increase queue size and drop the overflowed messages (#2617)
- Explicit close the net.Conn in user space wgProxy when close the wgProxy
- Add extra logs
2024-09-19 13:49:09 +02:00
Zoltan Papp
28cbb4b70f [client] Cancel the context of wg watcher when the go routine exit (#2612) 2024-09-17 12:10:17 +02:00
Zoltan Papp
1104c9c048 [client] Fix race condition while read/write conn status in peer conn (#2607) 2024-09-17 11:15:14 +02:00
Maycon Santos
5bc601111d [relay] Add health check attempt threshold (#2609)
* Add health check attempt threshold for receiver

* Add health check attempt threshold for sender
2024-09-17 10:04:17 +02:00
Zoltan Papp
b74951f29e [client] Enforce permissions on Win (#2568)
Enforce folder permission on Windows, giving only administrators and system access to the NetBird folder.
2024-09-16 22:42:37 +02:00
Zoltan Papp
97e10e440c Fix leaked server connections (#2596)
Fix leaked server connections

close unused connections in the client lib
close deprecated connection in the server lib
The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes.

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Add logging

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-16 16:11:10 +02:00
pascal-fischer
6c50b0c84b [management] Add transaction to addPeer (#2469)
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
2024-09-16 15:47:03 +02:00
pascal-fischer
730dd1733e [signal] Fix signal active peers metrics (#2591) 2024-09-15 16:46:55 +02:00
Bethuel Mmbaga
82739e2832 [management] fix legacy decrypting of empty values (#2595)
* allow legacy decrypting on empty values

* validate source size and padding limits

* added tests

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-15 16:22:46 +02:00
Maycon Santos
fa7767e612 Fix get management and signal state race condition (#2570)
* Fix get management and signal state race condition

* fix get full status lock
2024-09-15 16:07:26 +02:00
benniekiss
f1171198de [management] Add command flag to set metrics port for signal and relay service, and update management port (#2599)
* add flags to customize metrics port for relay and signal

* change management default metrics port to match other services
2024-09-14 10:34:32 +02:00
Zoltan Papp
9e041b7f82 Fix blocked net.Conn Close call (#2600) 2024-09-14 10:27:37 +02:00
Zoltan Papp
b4c8cf0a67 Change heartbeat timeout (#2598) 2024-09-14 10:12:54 +02:00
Carlos Hernandez
1ef51a4ffa [client] Ensure engine is stopped before starting it back (#2565)
Before starting a new instance of the engine, check if it is nil and stop the current instance
2024-09-13 16:46:59 +02:00
63 changed files with 2921 additions and 711 deletions

View File

@@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04

View File

@@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}

View File

@@ -8,8 +8,8 @@ import (
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))

View File

@@ -117,6 +117,11 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = WriteOutConfig(input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input)
}

View File

@@ -158,6 +158,7 @@ func (c *ConnectClient) run(
}
defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.isContextCancelled() {
@@ -267,6 +268,12 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
@@ -279,9 +286,10 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
runningChanOpen = false
}
<-engineCtx.Done()

View File

@@ -292,7 +292,7 @@ func (e *Engine) Start() error {
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
@@ -1115,10 +1115,7 @@ func (e *Engine) close() {
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
@@ -1360,12 +1357,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
}
func (e *Engine) restartEngine() {
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated")
e.clientCancel()
}
func (e *Engine) startNetworkMonitor() {
@@ -1387,6 +1388,7 @@ func (e *Engine) startNetworkMonitor() {
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
@@ -1396,7 +1398,7 @@ func (e *Engine) startNetworkMonitor() {
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
@@ -1421,6 +1423,20 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
nsGroupStates[i].Enabled = false
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {

View File

@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}()
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing routing message: %v", err)
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}

View File

@@ -89,8 +89,8 @@ type Conn struct {
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string)
statusRelay ConnStatus
statusICE ConnStatus
statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus
currentConnPriority ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
statusRelay: StatusDisconnected,
statusICE: StatusDisconnected,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
if conn.statusICE == StatusDisconnected {
if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
conn.statusICE = StatusConnected
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
@@ -484,16 +484,16 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
// switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
conn.workerRelay.EnableWgWatcher(conn.ctx)
err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.currentConnPriority = connPriorityRelay
}
changed := conn.statusICE != newState && newState != StatusConnecting
conn.statusICE = newState
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
select {
case conn.iCEDisconnected <- changed:
@@ -518,18 +518,22 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay = StatusConnected
conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
wgProxy := conn.wgProxyFactory.GetProxy()
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
@@ -538,7 +542,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay {
if conn.statusICE == StatusConnected {
if conn.statusICE.Get() == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return
}
@@ -551,7 +555,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
}
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
if err := wgProxy.CloseConn(); err != nil {
@@ -560,6 +563,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
return
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
@@ -594,8 +598,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.wgProxyRelay = nil
}
changed := conn.statusRelay != StatusDisconnected
conn.statusRelay = StatusDisconnected
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
select {
case conn.relayDisconnected <- changed:
@@ -661,8 +665,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
}
func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay = StatusDisconnected
conn.statusICE = StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
conn.statusICE.Set(StatusDisconnected)
peerState := State{
PubKey: conn.config.Key,
@@ -706,7 +710,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
}
func (conn *Conn) isRelayed() bool {
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
return false
}
@@ -718,11 +722,11 @@ func (conn *Conn) isRelayed() bool {
}
func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
return StatusConnected
}
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
return StatusConnecting
}
@@ -733,12 +737,12 @@ func (conn *Conn) isConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
return false
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay != StatusConnected {
if conn.statusRelay.Get() != StatusConnected {
return false
}
}
@@ -771,13 +775,12 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
err = wgProxy.CloseConn()
if err != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
}
return nil, nil, err
}

View File

@@ -1,6 +1,10 @@
package peer
import log "github.com/sirupsen/logrus"
import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const (
// StatusConnected indicate the peer is in connected state
@@ -12,7 +16,34 @@ const (
)
// ConnStatus describe the status of a peer's connection
type ConnStatus int
type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string {
switch s {

View File

@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
conn.statusICE = table.statusIce
conn.statusRelay = table.statusRelay
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")

View File

@@ -597,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
return d.nsGroupStates
}
@@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
defer d.mux.Unlock()
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status)
}

View File

@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
}
ctx, ctxCancel := context.WithCancel(ctx)
w.wgStateCheck(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
w.wgStateCheck(ctx, ctxCancel)
}
func (w *WorkerRelay) DisableWgWatcher() {
@@ -158,21 +158,22 @@ func (w *WorkerRelay) CloseConn() {
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
w.log.Debugf("WireGuard watcher started")
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
go func(lastHandshake time.Time) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
for {
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)

View File

@@ -1,4 +1,4 @@
package wgproxy
package ebpf
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package wgproxy
package ebpf
import (
"fmt"

View File

@@ -1,6 +1,6 @@
//go:build linux && !android
package wgproxy
package ebpf
import (
"context"
@@ -13,47 +13,49 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
loopbackAddr = "127.0.0.1"
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int
ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
rawConn net.PacketConn
conn transport.UDPConn
lastUsedPort uint16
rawConn net.PacketConn
conn transport.UDPConn
ctx context.Context
ctxCancel context.CancelFunc
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy
}
// listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) listen() error {
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
@@ -72,9 +74,11 @@ func (p *WGEBPFProxy) listen() error {
addr := net.UDPAddr{
Port: wgPorxyPort,
IP: net.ParseIP("127.0.0.1"),
IP: net.ParseIP(loopbackAddr),
}
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
@@ -91,108 +95,112 @@ func (p *WGEBPFProxy) listen() error {
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(wgEndpointPort, turnConn)
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
IP: net.ParseIP(loopbackAddr),
Port: int(wgEndpointPort),
}
return wgEndpoint, nil
}
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
return nil
}
// Free resources
// Free resources except the remoteConns will be keep open.
func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy")
var err1, err2, err3 error
if p.conn != nil {
err1 = p.conn.Close()
if p.ctx != nil && p.ctx.Err() != nil {
//nolint
return nil
}
err2 = p.ebpfManager.FreeWGProxy()
if p.rawConn != nil {
err3 = p.rawConn.Close()
p.ctxCancel()
var result *multierror.Error
if err := p.conn.Close(); err != nil {
result = multierror.Append(result, err)
}
if err1 != nil {
return err1
if err := p.ebpfManager.FreeWGProxy(); err != nil {
result = multierror.Append(result, err)
}
if err2 != nil {
return err2
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
}
return err3
return nberrors.FormatErrorOrNil(result)
}
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
var err error
defer func() {
p.removeTurnConn(endpointPort)
}()
for {
select {
case <-p.ctx.Done():
return
default:
var n int
n, err = remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
for p.ctx.Err() == nil {
if err := p.readAndForwardPacket(buf); err != nil {
if p.ctx.Err() != nil {
return
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
log.Errorf("failed to proxy packet to remote conn: %s", err)
}
}
}
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
if p.ctx.Err() == nil {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
}
return nil
}
if _, err := conn.Write(buf[:n]); err != nil {
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
}
return nil
}
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
@@ -206,11 +214,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
_, ok := p.turnConnStore[turnConnID]
if ok {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
}
delete(p.turnConnStore, turnConnID)
}
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {

View File

@@ -1,14 +1,13 @@
//go:build linux && !android
package wgproxy
package ebpf
import (
"context"
"testing"
)
func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)

View File

@@ -0,0 +1,44 @@
//go:build linux && !android
package ebpf
import (
"context"
"fmt"
"net"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
return addr, err
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
if err := e.remoteConn.Close(); err != nil {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}

View File

@@ -1,22 +0,0 @@
package wgproxy
import "context"
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy(ctx context.Context) Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(ctx, w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy != nil {
return w.ebpfProxy.Free()
}
return nil
}

View File

@@ -3,20 +3,26 @@
package wgproxy
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
)
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
type Factory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewFactory(userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
err := ebpfProxy.Listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f.ebpfProxy = ebpfProxy
return f
}
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
p := &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
return p
}
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@@ -2,8 +2,20 @@
package wgproxy
import "context"
import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
type Factory struct {
wgPort int
}
func NewFactory(_ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}
func (w *Factory) GetProxy() Proxy {
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
return nil
}

View File

@@ -1,12 +1,12 @@
package wgproxy
import (
"context"
"net"
)
// Proxy is a transfer layer between the Turn connection and the WireGuard
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(turnConn net.Conn) (net.Addr, error)
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
CloseConn() error
Free() error
}

View File

@@ -0,0 +1,128 @@
//go:build linux
package wgproxy
import (
"context"
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
"github.com/netbirdio/netbird/util"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
type mocConn struct {
closeChan chan struct{}
closed bool
}
func newMockConn() *mocConn {
return &mocConn{
closeChan: make(chan struct{}),
}
}
func (m *mocConn) Read(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Write(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Close() error {
if m.closed == true {
return nil
}
m.closed = true
close(m.closeChan)
return nil
}
func (m *mocConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mocConn) RemoteAddr() net.Addr {
return &net.UDPAddr{
IP: net.ParseIP("172.16.254.1"),
}
}
func (m *mocConn) SetDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: usp.NewWGUserSpaceProxy(51830),
},
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -1,120 +0,0 @@
package wgproxy
import (
"context"
"fmt"
"io"
"net"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(ctx)
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn
var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
p.cancel()
if p.localConn == nil {
return nil
}
return p.localConn.Close()
}
// Free doing nothing because this implementation of proxy does not have global state
func (p *WGUserSpaceProxy) Free() error {
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
log.Debugf("failed to read from wg interface conn: %s", err)
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if err == io.EOF {
p.cancel()
} else {
log.Debugf("failed to write to remote conn: %s", err)
}
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
p.cancel()
return
}
log.Errorf("failed to read from remote conn: %s", err)
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}
}

View File

@@ -0,0 +1,146 @@
package usp
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
closeMu sync.Mutex
closed bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *WGUserSpaceProxy) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
// prevent double close
if p.closed {
return nil
}
p.closed = true
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
return errors.FormatErrorOrNil(result)
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
return
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to remote conn: %s", err)
return
}
}
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}

4
go.mod
View File

@@ -95,9 +95,10 @@ require (
golang.org/x/term v0.21.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
gorm.io/gorm v1.25.7
nhooyr.io/websocket v1.8.11
)
@@ -151,6 +152,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-text/render v0.1.0 // indirect
github.com/go-text/typesetting v0.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect

8
go.sum
View File

@@ -238,6 +238,8 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
@@ -1224,12 +1226,14 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g=
gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=

View File

@@ -54,7 +54,7 @@ func Execute() error {
func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")

View File

@@ -263,6 +263,11 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
@@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList
}
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
@@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil {
return false, err
}
@@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {

View File

@@ -7,7 +7,6 @@ import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
@@ -115,12 +114,23 @@ func pkcs5Padding(ciphertext []byte) []byte {
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
paddingLen := int(src[srcLen-1])
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
if srcLen == 0 {
return nil, errors.New("input data is empty")
}
paddingLen := int(src[srcLen-1])
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
return nil, errors.New("invalid padding size")
}
// Verify that all padding bytes are the same
for i := 0; i < paddingLen; i++ {
if src[srcLen-1-i] != byte(paddingLen) {
return nil, errors.New("invalid padding")
}
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -1,6 +1,7 @@
package sqlite
import (
"bytes"
"testing"
)
@@ -95,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
func TestEncryptDecrypt(t *testing.T) {
// Generate a key for encryption/decryption
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
// Initialize the FieldEncrypt with the generated key
ec, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("Failed to create FieldEncrypt: %v", err)
}
// Test cases
testCases := []struct {
name string
input string
}{
{
name: "Empty String",
input: "",
},
{
name: "Short String",
input: "Hello",
},
{
name: "String with Spaces",
input: "Hello, World!",
},
{
name: "Long String",
input: "The quick brown fox jumps over the lazy dog.",
},
{
name: "Unicode Characters",
input: "こんにちは世界",
},
{
name: "Special Characters",
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
},
{
name: "Numeric String",
input: "1234567890",
},
{
name: "Repeated Characters",
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
},
{
name: "Multi-block String",
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
},
{
name: "Non-ASCII and ASCII Mix",
input: "Hello 世界 123",
},
}
for _, tc := range testCases {
t.Run(tc.name+" - Legacy", func(t *testing.T) {
// Legacy Encryption
encryptedLegacy := ec.LegacyEncrypt(tc.input)
if encryptedLegacy == "" {
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
}
// Legacy Decryption
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
if err != nil {
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedLegacy != tc.input {
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
}
})
t.Run(tc.name+" - New", func(t *testing.T) {
// New Encryption
encryptedNew, err := ec.Encrypt(tc.input)
if err != nil {
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
}
if encryptedNew == "" {
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
}
// New Decryption
decryptedNew, err := ec.Decrypt(encryptedNew)
if err != nil {
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedNew != tc.input {
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
}
})
}
}
func TestPKCS5UnPadding(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectError bool
}{
{
name: "Valid Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
expected: []byte("Hello, World!"),
},
{
name: "Empty Input",
input: []byte{},
expectError: true,
},
{
name: "Padding Length Zero",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
expectError: true,
},
{
name: "Padding Length Exceeds Block Size",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
expectError: true,
},
{
name: "Padding Length Exceeds Input Length",
input: []byte{5, 5, 5},
expectError: true,
},
{
name: "Invalid Padding Bytes",
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
expectError: true,
},
{
name: "Valid Single Byte Padding",
input: append([]byte("Hello, World!"), byte(1)),
expected: []byte("Hello, World!"),
},
{
name: "Invalid Mixed Padding Bytes",
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
expectError: true,
},
{
name: "Valid Full Block Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("Hello, World!"),
},
{
name: "Non-Padding Byte at End",
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
expectError: true,
},
{
name: "Valid Padding with Different Text Length",
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
expected: []byte("Test"),
},
{
name: "Padding Length Equal to Input Length",
input: bytes.Repeat([]byte{8}, 8),
expected: []byte{},
},
{
name: "Invalid Padding Length Zero (Again)",
input: append([]byte("Test"), byte(0)),
expectError: true,
},
{
name: "Padding Length Greater Than Input",
input: []byte{10},
expectError: true,
},
{
name: "Input Length Not Multiple of Block Size",
input: append([]byte("Invalid Length"), byte(1)),
expected: []byte("Invalid Length"),
},
{
name: "Valid Padding with Non-ASCII Characters",
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
expected: []byte("こんにちは"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs5UnPadding(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Did not expect error but got: %v", err)
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("Expected output %v, got %v", tt.expected, result)
}
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
type MockStore struct {
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil
}
return nil, fmt.Errorf("account not found")
return nil, status.NewPeerNotFoundError(peerId)
}
type MocAccountManager struct {

View File

@@ -2,6 +2,8 @@ package server
import (
"context"
"errors"
"net"
"os"
"path/filepath"
"strings"
@@ -46,6 +48,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"`
}
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
return f(s)
}
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
if !ok {
return status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return err
}
account.SetupKeys[setupKeyID].UsedTimes++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
allGroup, err := account.GetGroupAll()
if err != nil || allGroup == nil {
return errors.New("all group not found")
}
allGroup.Peers = append(allGroup.Peers, peerID)
return nil
}
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountId)
if err != nil {
return err
}
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
return nil
}
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[peer.AccountID]
if !ok {
return status.NewAccountNotFoundError(peer.AccountID)
}
account.Peers[peer.ID] = peer
return s.SaveAccount(ctx, account)
}
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[accountId]
if !ok {
return status.NewAccountNotFoundError(accountId)
}
account.Network.Serial++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
setupKey, ok := account.SetupKeys[key]
if !ok {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return setupKey, nil
}
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
var takenIps []net.IP
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps, nil
}
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
existingLabels := []string{}
for _, peer := range account.Peers {
if peer.DNSLabel != "" {
existingLabels = append(existingLabels, peer.DNSLabel)
}
}
return existingLabels, nil
}
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Network, nil
}
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
@@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
@@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
@@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return account, nil
@@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return "", status.NewSetupKeyNotFoundError()
}
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
}
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock()
defer s.mux.Unlock()

View File

@@ -950,7 +950,7 @@ components:
type: array
items:
type: string
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action:
description: Action to take upon policy match
type: string

View File

@@ -7,8 +7,6 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -16,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
// PeersHandler is a handler that returns peers of the account
@@ -215,7 +214,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -228,6 +227,21 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
return
}
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser {
peer, ok := account.Peers[peerID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
return
}
if peer.UserID != user.Id {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
}
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
@@ -12,20 +13,30 @@ import (
"time"
"github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel"
type ctxKey string
const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"
adminUser = "admin_user"
regularUser = "regular_user"
serviceUser = "service_user"
userIDKey ctxKey = "user_id"
)
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
@@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: claims.AccountId,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*nbpeer.Peer{
peers[0].ID: peers[0],
peers[1].ID: peers[1],
},
Peers: peersMap,
Users: map[string]*server.User{
"test_user": user,
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: claims.AccountId,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
@@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
Serial: 51,
},
}, user, nil
}
return account, account.Users[claims.UserId], nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
@@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: userID,
Domain: "hotmail.com",
AccountId: "test_id",
}
@@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
})
}
}
func TestGetAccessiblePeers(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
Key: "key1",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer1",
LoginExpirationEnabled: false,
UserID: regularUser,
}
peer2 := &nbpeer.Peer{
ID: "peer2",
Key: "key2",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer2",
LoginExpirationEnabled: false,
UserID: adminUser,
}
peer3 := &nbpeer.Peer{
ID: "peer3",
Key: "key3",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer3",
LoginExpirationEnabled: false,
UserID: regularUser,
}
p := initTestMetaData(peer1, peer2, peer3)
tt := []struct {
name string
peerID string
callerUserID string
expectedStatus int
expectedPeers []string
}{
{
name: "non admin user can access owned peer",
peerID: "peer1",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer2", "peer3"},
},
{
name: "non admin user can't access unowned peer",
peerID: "peer2",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "admin user can access owned peer",
peerID: "peer2",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
},
{
name: "admin user can access unowned peer",
peerID: "peer3",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
{
name: "service user can access unowned peer",
peerID: "peer3",
callerUserID: serviceUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
if res.StatusCode != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
defer res.Body.Close()
var accessiblePeers []api.AccessiblePeer
err = json.Unmarshal(body, &accessiblePeers)
if err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
peerIDs := make([]string, len(accessiblePeers))
for i, peer := range accessiblePeers {
peerIDs[i] = peer.Id
}
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
})
}
}

View File

@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
}
func Test_LoginPerformance(t *testing.T) {
if os.Getenv("CI") == "true" {
t.Skip("Skipping on CI")
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping test on CI or Windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1},
// {"L", 500, 1},
// {"XL", 750, 1},
{"XXL", 2000, 1},
{"XXL", 5000, 1},
}
log.SetOutput(io.Discard)
@@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
}
defer mgmtServer.GracefulStop()
t.Logf("management setup complete, start registering peers")
var counter int32
var counterStart int32
var wg sync.WaitGroup
var wgAccount sync.WaitGroup
var mu sync.Mutex
messageCalls := []func() error{}
for j := 0; j < bc.accounts; j++ {
wg.Add(1)
wgAccount.Add(1)
var wgPeer sync.WaitGroup
go func(j int, counter *int32, counterStart *int32) {
defer wg.Done()
defer wgAccount.Done()
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
if err != nil {
@@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
return
}
startTime := time.Now()
for i := 0; i < bc.peers; i++ {
wgPeer.Add(1)
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Logf("failed to generate key: %v", err)
@@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
mu.Lock()
messageCalls = append(messageCalls, login)
mu.Unlock()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
go func(peerLogin PeerLogin, counterStart *int32) {
defer wgPeer.Done()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
}(peerLogin, counterStart)
}
wgPeer.Wait()
t.Logf("Time for registration: %s", time.Since(startTime))
}(j, &counter, &counterStart)
}
wg.Wait()
wgAccount.Wait()
t.Logf("prepared %d login calls", len(messageCalls))
testLoginPerformance(t, messageCalls)

View File

@@ -11,6 +11,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto"
@@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}()
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(ctx, userID, account)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = account.FindPeerByPubKey(peer.Key)
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
Timestamp: time.Now().UTC(),
AccountID: account.Id,
AccountID: accountID,
}
var ephemeral bool
setupKeyName := ""
if !addedByUser {
// validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey)
if err != nil {
return nil, nil, nil, err
}
var newPeer *nbpeer.Peer
if !sk.IsValid() {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
account.SetupKeys[sk.Key] = sk.IncrementUsage()
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
ephemeral = sk.Ephemeral
setupKeyName = sk.Name
} else {
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
}
takenIps := account.getTakenIPs()
existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
if err != nil {
return nil, nil, nil, err
}
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, nil, nil, err
}
registrationTime := time.Now().UTC()
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
Key: peer.Key,
SetupKey: upperKey,
IP: nextIp,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: newLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var groupsToAdd []string
var setupKeyID string
var setupKeyName string
var ephemeral bool
if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
if err != nil {
return fmt.Errorf("failed to get user groups: %w", err)
}
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
// Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, err
}
group.Peers = append(group.Peers, newPeer.ID)
if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
var groupsToAdd []string
if addedByUser {
groupsToAdd, err = account.getUserGroups(userID)
if err != nil {
return nil, nil, nil, err
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd = sk.AutoGroups
ephemeral = sk.Ephemeral
setupKeyID = sk.Id
setupKeyName = sk.Name
}
} else {
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
if err != nil {
return nil, nil, nil, err
}
}
if len(groupsToAdd) > 0 {
for _, s := range groupsToAdd {
if g, ok := account.Groups[s]; ok && g.Name != "All" {
g.Peers = append(g.Peers, newPeer.ID)
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
if addedByUser {
user, err := account.FindUser(userID)
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
return fmt.Errorf("failed to get free DNS label: %w", err)
}
user.updateLastLogin(newPeer.LastLogin)
}
account.Peers[newPeer.ID] = newPeer
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account)
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
if err != nil {
return fmt.Errorf("failed to get free IP: %w", err)
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
if err != nil {
return fmt.Errorf("failed to update user last login: %w", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
}
// Account is saved, we can release the lock
unlock()
unlock = nil
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
unlock()
unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
postureChecks := am.getPeerPostureChecks(account, peer)
postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return err
}
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
@@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}

View File

@@ -7,20 +7,24 @@ import (
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route"
)
@@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}
func Test_RegisterPeerByUser(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
UserID: existingUserID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
LastLogin: time.Now(),
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
}
func Test_RegisterPeerBySetupKey(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.Error(t, err)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.NotContains(t, account.Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
assert.Equal(t, uint64(0), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
@@ -14,6 +15,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
@@ -33,6 +35,7 @@ import (
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
)
@@ -415,13 +418,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
return nil, status.NewSetupKeyNotFoundError()
}
if key.AccountID == "" {
@@ -474,15 +476,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
return nil, status.NewUserNotFoundError(userID)
}
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
@@ -535,7 +537,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -595,7 +597,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -612,12 +614,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -631,12 +632,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -650,12 +650,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
@@ -677,61 +676,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var key SetupKey
var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting setup key from store")
return "", status.NewSetupKeyNotFoundError()
}
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return accountID, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
return status.NewUserNotFoundError(userID)
}
return status.Errorf(status.Internal, "issue getting user from store")
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
return s.db.Save(user).Error
return s.db.Save(&user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -790,6 +845,16 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
}
// NewMysqlStore creates a new MySql store.
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(mysql.Open(dsn), getGormConfig())
if err != nil {
return nil, err
}
return NewSqlStore(ctx, db, MySqlStoreEngine, metrics)
}
func getGormConfig() *gorm.Config {
return &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
@@ -807,6 +872,15 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store,
return NewPostgresqlStore(ctx, dsn, metrics)
}
// newMySqlStore initializes a new MySql store.
func newMySqlStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
dsn, ok := os.LookupEnv(mySqlDsnEnv)
if !ok {
return nil, fmt.Errorf("%s is not set", mySqlDsnEnv)
}
return NewMysqlStore(ctx, dsn, metrics)
}
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewSqliteStore(ctx, dataDir, metrics)
@@ -850,3 +924,127 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}
func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}

View File

@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
func TestSqlite_GetTakenIPs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []net.IP{}, takenIPs)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip1 := net.IP{1, 1, 1, 1}.To16()
assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
DNSLabel: "peer2.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip := net.IP{100, 64, 0, 0}.To16()
assert.Equal(t, ip, network.Net.IP)
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
assert.Equal(t, "", network.Dns)
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
assert.Equal(t, uint64(0), network.Serial)
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name)
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 0, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 1, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}

View File

@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error {
return Errorf(NotFound, "setup key not found")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}

View File

@@ -27,6 +27,15 @@ import (
"github.com/netbirdio/netbird/route"
)
type LockingStrength string
const (
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
)
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
@@ -41,7 +50,7 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
@@ -60,14 +69,24 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}
type StoreEngine string
@@ -76,8 +95,10 @@ const (
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
PostgresStoreEngine StoreEngine = "postgres"
MySqlStoreEngine StoreEngine = "mysql"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mySqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
func getStoreEngineFromEnv() StoreEngine {
@@ -88,11 +109,12 @@ func getStoreEngineFromEnv() StoreEngine {
}
value := StoreEngine(strings.ToLower(kind))
if value == SqliteStoreEngine || value == PostgresStoreEngine {
switch value {
case SqliteStoreEngine, PostgresStoreEngine, MySqlStoreEngine:
return value
default:
return SqliteStoreEngine
}
return SqliteStoreEngine
}
// getStoreEngine determines the store engine to use.
@@ -139,6 +161,9 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel
case PostgresStoreEngine:
log.WithContext(ctx).Info("using Postgres store engine")
return newPostgresStore(ctx, metrics)
case MySqlStoreEngine:
log.WithContext(ctx).Info("using MySQL store engine")
return newMySqlStore(ctx, metrics)
default:
return nil, fmt.Errorf("unsupported kind of store: %s", kind)
}

View File

@@ -0,0 +1,120 @@
{
"Accounts": {
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
"CreatedBy": "",
"Domain": "test.com",
"DomainCategory": "private",
"IsDomainPrimaryAccount": true,
"SetupKeys": {
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"Name": "Default key",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["cfefqs706sqkneg59g2g"],
"UsageLimit": 0,
"Ephemeral": false
},
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"Name": "Faulty key with non existing group",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["abcd"],
"UsageLimit": 0,
"Ephemeral": false
}
},
"Network": {
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
"Net": {
"IP": "100.64.0.0",
"Mask": "//8AAA=="
},
"Dns": "",
"Serial": 0
},
"Peers": {},
"Users": {
"edafee4e-63fb-11ec-90d6-0242ac120003": {
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "admin",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": ["cfefqs706sqkneg59g3g"],
"PATs": {},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
},
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "user",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": null,
"PATs": {
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
"UserID": "",
"Name": "",
"HashedToken": "SoMeHaShEdToKeN",
"ExpirationDate": "2023-02-27T00:00:00Z",
"CreatedBy": "user",
"CreatedAt": "2023-01-01T00:00:00Z",
"LastUsed": "2023-02-01T00:00:00Z"
}
},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
}
},
"Groups": {
"cfefqs706sqkneg59g4g": {
"ID": "cfefqs706sqkneg59g4g",
"Name": "All",
"Peers": []
},
"cfefqs706sqkneg59g3g": {
"ID": "cfefqs706sqkneg59g3g",
"Name": "AwesomeGroup1",
"Peers": []
},
"cfefqs706sqkneg59g2g": {
"ID": "cfefqs706sqkneg59g2g",
"Name": "AwesomeGroup2",
"Peers": []
}
},
"Rules": null,
"Policies": [],
"Routes": null,
"NameServerGroups": null,
"DNSSettings": null,
"Settings": {
"PeerLoginExpirationEnabled": false,
"PeerLoginExpiration": 86400000000000,
"GroupsPropagationEnabled": false,
"JWTGroupsEnabled": false,
"JWTGroupsClaimName": ""
}
}
},
"InstallationID": ""
}

View File

@@ -70,7 +70,7 @@ type User struct {
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP
LastLogin time.Time
LastLogin time.Time `gorm:"type:TIMESTAMP;null;default:null"`
// CreatedAt records the time the user was created
CreatedAt time.Time
@@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
func (u *User) updateLastLogin(login time.Time) {
u.LastLogin = login
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
@@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}

View File

@@ -58,37 +58,65 @@ func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
// server and forwarding them to the upper layer content reader.
type connContainer struct {
log *log.Entry
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
ctx context.Context
cancel context.CancelFunc
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background())
return &connContainer{
log: log,
conn: conn,
messages: messages,
ctx: ctx,
cancel: cancel,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
msg.Free()
return
}
cc.messages <- msg
select {
case cc.messages <- msg:
case <-cc.ctx.Done():
msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
}
}
func (cc *connContainer) close() {
cc.cancel()
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
close(cc.messages)
for msg := range cc.messages {
msg.Free()
}
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@@ -120,8 +148,8 @@ type Client struct {
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
c := &Client{
log: log.WithFields(log.Fields{"relay": serverURL}),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
@@ -134,11 +162,13 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
},
conns: make(map[string]*connContainer),
}
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
return c
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.log.Infof("connecting to relay server: %s", c.connectionURL)
c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
@@ -159,7 +189,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
c.log.Infof("relay connection established with: %s", c.connectionURL)
c.log.Infof("relay connection established")
return nil
}
@@ -181,11 +211,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
return nil, ErrConnAlreadyExists
}
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
c.log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil
}
@@ -229,7 +259,7 @@ func (c *Client) connect() error {
if err != nil {
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection: %s", cErr)
c.log.Errorf("failed to close connection: %s", cErr)
}
return err
}
@@ -240,19 +270,19 @@ func (c *Client) connect() error {
func (c *Client) handShake() error {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
log.Errorf("failed to marshal auth message: %s", err)
c.log.Errorf("failed to marshal auth message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send auth message: %s", err)
c.log.Errorf("failed to send auth message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf)
if err != nil {
log.Errorf("failed to read auth response: %s", err)
c.log.Errorf("failed to read auth response: %s", err)
return err
}
@@ -263,12 +293,12 @@ func (c *Client) handShake() error {
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
c.log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType)
c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
@@ -285,7 +315,7 @@ func (c *Client) handShake() error {
func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var (
@@ -297,6 +327,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
@@ -343,7 +374,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose:
log.Debugf("relay connection close by server")
c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
@@ -412,14 +443,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
log.Errorf("failed to marshal transport message: %s", err)
c.log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)
c.log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
@@ -433,12 +464,15 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
_ = conn.Close() // ignore the err because the readLoop will handle it
if err := conn.Close(); err != nil {
// ignore the err handling because the readLoop will handle it
c.log.Warnf("failed to close connection: %s", err)
}
return
case <-c.parentCtx.Done():
err := c.close(true)
if err != nil {
log.Errorf("failed to teardown connection: %s", err)
c.log.Errorf("failed to teardown connection: %s", err)
}
return
}
@@ -464,8 +498,9 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
container.close()
c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id)
container.close()
return nil
}
@@ -478,10 +513,12 @@ func (c *Client) close(gracefullyExit bool) error {
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
c.log.Warn("relay connection was already marked as not running")
return nil
}
c.serviceIsRunning = false
c.log.Infof("closing all peer connections")
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
@@ -489,8 +526,9 @@ func (c *Client) close(gracefullyExit bool) error {
err = c.relayConn.Close()
c.mu.Unlock()
c.log.Infof("waiting for read loop to close")
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.connectionURL)
c.log.Infof("relay connection closed")
return err
}

View File

@@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
}
}
func TestCloseNotDrainedChannel(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
// the internal channel buffer size is 2. So we should overflow it
for i := 0; i < 5; i++ {
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
}
// wait for delivery
time.Sleep(1 * time.Second)
err = connBobToAlice.Close()
if err != nil {
t.Errorf("failed to close channel: %s", err)
}
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:

View File

@@ -3,7 +3,6 @@ package client
import (
"container/list"
"context"
"errors"
"fmt"
"net"
"reflect"
@@ -17,8 +16,6 @@ import (
var (
relayCleanupInterval = 60 * time.Second
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
@@ -92,67 +89,23 @@ func (m *Manager) Serve() error {
}
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
totalServers := len(m.serverURLs)
successChan := make(chan *Client, 1)
errChan := make(chan error, len(m.serverURLs))
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
defer cancel()
sem := make(chan struct{}, maxConcurrentServers)
for _, url := range m.serverURLs {
sem <- struct{}{}
go func(url string) {
defer func() { <-sem }()
m.connect(m.ctx, url, successChan, errChan)
}(url)
sp := ServerPicker{
TokenStore: m.tokenStore,
PeerID: m.peerID,
}
var errCount int
for {
select {
case client := <-successChan:
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
case err := <-errChan:
errCount++
log.Warnf("Connection attempt failed: %v", err)
if errCount == totalServers {
return errors.New("failed to connect to any relay server: all attempts failed")
}
case <-ctx.Done():
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
client, err := sp.PickServer(m.ctx, m.serverURLs)
if err != nil {
return err
}
}
m.relayClient = client
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
// TODO: abort the connection if another connection was successful
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
if err := relayClient.Connect(); err != nil {
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
return
}
select {
case successChan <- relayClient:
// This client was the first to connect successfully
default:
if err := relayClient.Close(); err != nil {
log.Debugf("failed to close relay client: %s", err)
}
}
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be

98
relay/client/picker.go Normal file
View File

@@ -0,0 +1,98 @@
package client
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
const (
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
type connResult struct {
RelayClient *Client
Url string
Err error
}
type ServerPicker struct {
TokenStore *auth.TokenStore
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(urls)
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
defer func() {
<-concurrentLimiter
}()
sp.startConnection(parentCtx, connResultChan, url)
}(url)
}
go sp.processConnResults(connResultChan, successChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect()
resultChan <- connResult{
RelayClient: relayClient,
Url: url,
Err: err,
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
if hasSuccess {
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
if err := cr.RelayClient.Close(); err != nil {
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
}
continue
}
hasSuccess = true
successChan <- cr
}
close(successChan)
}

View File

@@ -0,0 +1,31 @@
package client
import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
go func() {
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
if err == nil {
t.Error(err)
}
cancel()
}()
<-ctx.Done()
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Errorf("PickServer() took too long to complete")
}
}

View File

@@ -23,15 +23,12 @@ import (
"github.com/netbirdio/netbird/util"
)
const (
metricsPort = 9090
)
type Config struct {
ListenAddress string
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
// it is a domain:port or ip:port
ExposedAddress string
MetricsPort int
LetsencryptEmail string
LetsencryptDataDir string
LetsencryptDomains []string
@@ -80,6 +77,7 @@ func init() {
cobraConfig = &Config{}
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
@@ -116,7 +114,7 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err)
}
metricsServer, err := metrics.NewServer(metricsPort, "")
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)

View File

@@ -3,10 +3,12 @@ package healthcheck
import (
"context"
"time"
log "github.com/sirupsen/logrus"
)
var (
heartbeatTimeout = healthCheckInterval + 3*time.Second
heartbeatTimeout = healthCheckInterval + 10*time.Second
)
// Receiver is a healthcheck receiver
@@ -14,23 +16,26 @@ var (
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct {
OnTimeout chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
OnTimeout chan struct{}
log *log.Entry
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
attemptThreshold int
}
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver() *Receiver {
func NewReceiver(log *log.Entry) *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
OnTimeout: make(chan struct{}, 1),
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
OnTimeout: make(chan struct{}, 1),
log: log,
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
go r.waitForHealthcheck()
@@ -56,16 +61,23 @@ func (r *Receiver) waitForHealthcheck() {
defer r.ctxCancel()
defer close(r.OnTimeout)
failureCounter := 0
for {
select {
case <-r.heartbeat:
r.alive = true
failureCounter = 0
case <-ticker.C:
if r.alive {
r.alive = false
continue
}
failureCounter++
if failureCounter < r.attemptThreshold {
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
continue
}
r.notifyTimeout()
return
case <-r.ctx.Done():

View File

@@ -1,13 +1,18 @@
package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second
r := NewReceiver()
r := NewReceiver(log.WithContext(context.Background()))
select {
case <-r.OnTimeout:
@@ -19,7 +24,7 @@ func TestNewReceiver(t *testing.T) {
func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second
r := NewReceiver()
r := NewReceiver(log.WithContext(context.Background()))
select {
case <-r.OnTimeout:
@@ -30,7 +35,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second
r := NewReceiver()
r := NewReceiver(log.WithContext(context.Background()))
r.Heartbeat()
@@ -40,3 +45,53 @@ func TestNewReceiverAck(t *testing.T) {
case <-time.After(3 * time.Second):
}
}
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
defer func() {
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
t.Logf("reset counter once")
}
select {
case <-receiver.OnTimeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}

View File

@@ -2,12 +2,21 @@ package healthcheck
import (
"context"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 5 * time.Second
healthCheckTimeout = 20 * time.Second
)
// Sender is a healthcheck sender
@@ -15,20 +24,25 @@ var (
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
ack chan struct{}
ack chan struct{}
alive bool
attemptThreshold int
}
// NewSender creates a new healthcheck sender
func NewSender() *Sender {
func NewSender(log *log.Entry) *Sender {
hc := &Sender{
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
log: log,
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
return hc
@@ -46,23 +60,51 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
defer timeoutTimer.Stop()
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
defer timeoutTicker.Stop()
defer close(hc.HealthCheck)
defer close(hc.Timeout)
failureCounter := 0
for {
select {
case <-ticker.C:
hc.HealthCheck <- struct{}{}
case <-timeoutTimer.C:
case <-timeoutTicker.C:
if hc.alive {
hc.alive = false
continue
}
failureCounter++
if failureCounter < hc.attemptThreshold {
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
continue
}
hc.Timeout <- struct{}{}
return
case <-hc.ack:
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
failureCounter = 0
hc.alive = true
case <-ctx.Done():
return
}
}
}
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -2,9 +2,12 @@ package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestMain(m *testing.M) {
@@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -38,7 +41,7 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
select {
@@ -50,7 +53,7 @@ func TestNewHealthFailed(t *testing.T) {
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -101,3 +104,102 @@ func TestTimeoutReset(t *testing.T) {
t.Fatalf("is not exited")
}
}
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond
defer func() {
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx)
go func() {
responded := false
for {
select {
case <-ctx.Done():
return
case _, ok := <-sender.HealthCheck:
if !ok {
return
}
if tc.resetCounterOnce && !responded {
responded = true
sender.OnHCResponse()
}
}
}
}()
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
select {
case <-sender.Timeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -49,7 +49,7 @@ func (p *Peer) Work() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := healthcheck.NewSender()
hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx)
go p.handleHealthcheckEvents(ctx, hc)
@@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
// connection.
func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock()
defer p.connMu.Unlock()
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil {
p.log.Errorf("failed to send close message to peer: %s", p.String())
@@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
}
func (p *Peer) Close() {
p.connMu.Lock()
defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
}
// String returns the peer ID
@@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
p.log.Info("peer connection closed due healthcheck timeout")
return
case <-ctx.Done():
return

View File

@@ -19,10 +19,14 @@ func NewStore() *Store {
}
// AddPeer adds a peer to the store
// todo: consider to close peer conn if the peer already exists
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()]
if ok {
odlPeer.Close()
}
s.peers[peer.String()] = peer
}

View File

@@ -2,13 +2,57 @@ package server
import (
"context"
"net"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
@@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
conn := &mockConn{}
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1)
s.AddPeer(p2)

View File

@@ -29,12 +29,9 @@ import (
"google.golang.org/grpc/keepalive"
)
const (
metricsPort = 9090
)
var (
signalPort int
metricsPort int
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
@@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")

View File

@@ -82,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) {
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
peer.Id, peer.StreamID, pp.StreamID)
registry.Peers.Store(peer.Id, peer)
return
}
log.Debugf("peer registered [%s]", peer.Id)
registry.metrics.ActivePeers.Add(context.Background(), 1)
// record time as milliseconds
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
@@ -105,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) {
peer.Id, pp.StreamID, peer.StreamID)
return
}
registry.metrics.ActivePeers.Add(context.Background(), -1)
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
}
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
}

View File

@@ -133,8 +133,6 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
s.metrics.ActivePeers.Add(stream.Context(), 1)
return p, nil
} else {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
@@ -151,7 +149,6 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
s.registry.Deregister(p)
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
s.metrics.ActivePeers.Add(context.Background(), -1)
}
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {

View File

@@ -10,51 +10,30 @@ import (
log "github.com/sirupsen/logrus"
)
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
err = EnforcePermission(file)
if err != nil {
return err
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
return writeJson(file, obj, configDir, configFileName)
}
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
return writeJson(file, obj, configDir, configFileName)
}
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
}
func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file)
if err == nil {
@@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) {
}
err := os.MkdirAll(configDir, 0750)
if err != nil {
return "", "", err
}
return configDir, configFileName, err
}

7
util/permission.go Normal file
View File

@@ -0,0 +1,7 @@
//go:build !windows
package util
func EnforcePermission(dirPath string) error {
return nil
}

View File

@@ -0,0 +1,86 @@
package util
import (
"path/filepath"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
securityFlags = windows.OWNER_SECURITY_INFORMATION |
windows.GROUP_SECURITY_INFORMATION |
windows.DACL_SECURITY_INFORMATION |
windows.PROTECTED_DACL_SECURITY_INFORMATION
)
func EnforcePermission(file string) error {
dirPath := filepath.Dir(file)
user, group, err := sids()
if err != nil {
return err
}
adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
if err != nil {
return err
}
explicitAccess := []windows.EXPLICIT_ACCESS{
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_USER,
TrusteeValue: windows.TrusteeValueFromSID(user),
},
},
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
},
},
}
dacl, err := windows.ACLFromEntries(explicitAccess, nil)
if err != nil {
return err
}
return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
}
func sids() (*windows.SID, *windows.SID, error) {
var token windows.Token
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
if err != nil {
return nil, nil, err
}
defer func() {
if err := token.Close(); err != nil {
log.Errorf("failed to close process token: %v", err)
}
}()
tu, err := token.GetTokenUser()
if err != nil {
return nil, nil, err
}
pg, err := token.GetTokenPrimaryGroup()
if err != nil {
return nil, nil, err
}
return tu.User.Sid, pg.PrimaryGroup, nil
}