Compare commits

...

45 Commits

Author SHA1 Message Date
Misha Bragin
515ce9e3af Update management/server/sqlite_store.go 2024-04-17 20:55:32 +02:00
Misha Bragin
89383b7f01 Update management/server/sqlite_store.go 2024-04-17 20:55:01 +02:00
Misha Bragin
db34162733 Update management/server/sqlite_store.go 2024-04-17 20:54:14 +02:00
Misha Bragin
bd761e2177 Update management/server/sqlite_store.go 2024-04-17 20:53:32 +02:00
Misha Bragin
4e1b95a4c6 Update management/server/sqlite_store.go 2024-04-17 20:53:24 +02:00
Misha Bragin
05993af7bf Update management/server/sqlite_store.go 2024-04-17 20:53:11 +02:00
braginini
9d1cb00570 Fix setup keys test 2024-04-17 20:27:55 +02:00
braginini
543731df45 Fix setup keys test 2024-04-17 19:58:24 +02:00
braginini
e6628ec231 Fix setup keys 2024-04-17 19:48:09 +02:00
braginini
41d4dd2aff reduce log level of scheduler to trace 2024-04-17 19:34:59 +02:00
braginini
30bed57711 Fix account deletion 2024-04-17 19:12:53 +02:00
braginini
6960b68322 Add pats to test save account 2024-04-17 19:07:17 +02:00
braginini
3b3aa18148 Store setup keys and ns groups in a batch 2024-04-17 18:32:13 +02:00
braginini
93045f3e3a Fix rand lint issue 2024-04-17 18:07:02 +02:00
braginini
fd3c1dea8e Add save large account test 2024-04-17 18:02:10 +02:00
braginini
48aff7a26e Fix test compilation errors 2024-04-17 17:39:28 +02:00
braginini
83dfe8e3a3 Fix test compilation errors 2024-04-17 17:27:23 +02:00
braginini
38e10af2d9 Add accountID reference 2024-04-17 17:16:56 +02:00
braginini
99854a126a Add comments 2024-04-17 17:08:01 +02:00
braginini
a75f982fcd Copy account when storing to avoid reference issues 2024-04-17 17:03:21 +02:00
braginini
e7a6483912 Optimize all other objects storing in SQLite 2024-04-17 12:35:41 +02:00
braginini
30ede299b8 Optimize peer storing in SQLite 2024-04-17 11:50:33 +02:00
Viktor Liu
e3b76448f3 Fix ICE endpoint remote port in status command (#1851) 2024-04-16 14:01:59 +02:00
Viktor Liu
e0de86d6c9 Use fixed activity codes (#1846)
* Add duplicate constants check
2024-04-15 14:15:46 +02:00
Zoltan Papp
5204d07811 Pass integrated validator for API (#1814)
Pass integrated validator for API handler
2024-04-15 12:08:38 +02:00
Viktor Liu
5ea24ba56e Add sysctl opts to prevent reverse path filtering from dropping fwmark packets (#1839) 2024-04-12 17:53:07 +02:00
Viktor Liu
d30cf8706a Allow disabling custom routing (#1840) 2024-04-12 16:53:11 +02:00
Viktor Liu
15a2feb723 Use fixed preference for rules (#1836) 2024-04-12 16:07:03 +02:00
Viktor Liu
91b2f9fc51 Use route active store (#1834) 2024-04-12 15:22:40 +02:00
Carlos Hernandez
76702c8a09 Add safe read/write to route map (#1760) 2024-04-11 22:12:23 +02:00
Viktor Liu
061f673a4f Don't use the custom dialer as non-root (#1823) 2024-04-11 15:29:03 +02:00
Zoltan Papp
9505805313 Rename variable (#1829) 2024-04-11 14:08:03 +02:00
Maycon Santos
704c67dec8 Allow owners that did not create the account to delete it (#1825)
Sometimes the Owner role will be passed to new users, and they need to be able to delete the account
2024-04-11 10:02:51 +02:00
pascal-fischer
3ed2f08f3c Add latency based routing (#1732)
Now that we have the latency between peers available we can use this data to consider when choosing the best route. This way the route with the routing peer with the lower latency will be preferred over others with the same target network.
2024-04-09 21:20:02 +02:00
Maycon Santos
4c83408f27 Add log-level to the management's docker service command (#1820) 2024-04-09 21:00:43 +02:00
Viktor Liu
90bd39c740 Log panics (#1818) 2024-04-09 20:27:27 +02:00
Maycon Santos
dd0cf41147 Auto restart Windows agent daemon service (#1819)
This enables auto restart of the windows agent daemon service on event of failure
2024-04-09 20:10:59 +02:00
pascal-fischer
22b2caffc6 Remove dns based cloud detection (#1812)
* remove dns based cloud checks

* remove dns based cloud checks
2024-04-09 19:01:31 +02:00
Viktor Liu
c1f66d1354 Retry macOS route command (#1817) 2024-04-09 15:27:19 +02:00
Viktor Liu
ac0fe6025b Fix routing issues with MacOS (#1815)
* Handle zones properly

* Use host routes for single IPs 

* Add GOOS and GOARCH to startup log

* Log powershell command
2024-04-09 13:25:14 +02:00
verytrap
c28657710a Fix function names in comments (#1816)
Signed-off-by: verytrap <wangqiuyue@outlook.com>
2024-04-09 13:18:38 +02:00
Maycon Santos
3875c29f6b Revert "Rollback new routing functionality (#1805)" (#1813)
This reverts commit 9f32ccd453.
2024-04-08 18:56:52 +02:00
Viktor Liu
9f32ccd453 Rollback new routing functionality (#1805) 2024-04-05 20:38:49 +02:00
trax
1d1d057e7d Change the dashboard image pull from wiretrustee to netbirdio (#1804) 2024-04-05 13:51:28 +02:00
Viktor Liu
3461b1bb90 Expect correct conn type (#1801) 2024-04-05 00:10:32 +02:00
59 changed files with 1033 additions and 413 deletions

View File

@@ -33,6 +33,10 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v4
with: with:

View File

@@ -64,6 +64,10 @@ var installCmd = &cobra.Command{
} }
} }
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), svcConfig) s, err := newSVC(newProgram(ctx, cancel), svcConfig)

View File

@@ -4,6 +4,8 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"runtime"
"runtime/debug"
"strings" "strings"
"time" "time"
@@ -93,7 +95,13 @@ func runClient(
relayProbe *Probe, relayProbe *Probe,
wgProbe *Probe, wgProbe *Probe,
) error { ) error {
log.Infof("starting NetBird client version %s", version.NetbirdVersion()) defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required. // Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config. // Otherwise, we might not be able to connect to the management server to retrieve new config.

View File

@@ -794,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(), FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected, ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
} }
} }
e.statusRecorder.ReplaceOfflinePeers(replacement) e.statusRecorder.ReplaceOfflinePeers(replacement)

View File

@@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
} }
conn.agent, err = ice.NewAgent(agentConfig) conn.agent, err = ice.NewAgent(agentConfig)
if err != nil { if err != nil {
return err return err
} }
@@ -285,6 +284,7 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
ConnStatus: conn.status, ConnStatus: conn.status,
Mux: new(sync.RWMutex),
} }
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
@@ -344,6 +344,7 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: conn.status, ConnStatus: conn.status,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
} }
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
@@ -465,9 +466,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
LocalIceCandidateType: pair.Local.Type().String(), LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(), RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()), LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()), RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Direct: !isRelayCandidate(pair.Local), Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled, RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
} }
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
@@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: conn.status, ConnStatus: conn.status,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
} }
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {

View File

@@ -14,6 +14,7 @@ import (
// State contains the latest state of a peer // State contains the latest state of a peer
type State struct { type State struct {
Mux *sync.RWMutex
IP string IP string
PubKey string PubKey string
FQDN string FQDN string
@@ -30,7 +31,38 @@ type State struct {
BytesRx int64 BytesRx int64
Latency time.Duration Latency time.Duration
RosenpassEnabled bool RosenpassEnabled bool
Routes map[string]struct{} routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
} }
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey, PubKey: peerPubKey,
ConnStatus: StatusDisconnected, ConnStatus: StatusDisconnected,
FQDN: fqdn, FQDN: fqdn,
Mux: new(sync.RWMutex),
} }
d.peerListChangedForNotification = true d.peerListChangedForNotification = true
return nil return nil
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP peerState.IP = receivedState.IP
} }
if receivedState.Routes != nil { if receivedState.GetRoutes() != nil {
peerState.Routes = receivedState.Routes peerState.SetRoutes(receivedState.GetRoutes())
} }
skipNotification := shouldSkipNotify(receivedState, peerState) skipNotification := shouldSkipNotify(receivedState, peerState)
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError) s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true return true
} }
return false return false
} }

View File

@@ -3,6 +3,7 @@ package peer
import ( import (
"errors" "errors"
"testing" "testing"
"sync"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -18,6 +19,7 @@ type routerPeerStatus struct {
connected bool connected bool
relayed bool relayed bool
direct bool direct bool
latency time.Duration
} }
type routesUpdate struct { type routesUpdate struct {
@@ -68,6 +70,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
connected: peerStatus.ConnStatus == peer.StatusConnected, connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed, relayed: peerStatus.Relayed,
direct: peerStatus.Direct, direct: peerStatus.Direct,
latency: peerStatus.Latency,
} }
} }
return routePeerStatuses return routePeerStatuses
@@ -83,11 +86,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
// * Non-relayed: Routes without relays are preferred. // * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored. // * Direct connections: Routes with direct peer connections are favored.
// * Stability: In case of equal scores, the currently active route (if any) is maintained. // * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized.
// //
// It returns the ID of the selected optimal route. // It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
chosen := "" chosen := ""
chosenScore := 0 chosenScore := float64(0)
currScore := float64(0)
currID := "" currID := ""
if c.chosenRoute != nil { if c.chosenRoute != nil {
@@ -95,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
} }
for _, r := range c.routes { for _, r := range c.routes {
tempScore := 0 tempScore := float64(0)
peerStatus, found := routePeerStatuses[r.ID] peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected { if !found || !peerStatus.connected {
continue continue
@@ -103,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
if r.Metric < route.MaxMetric { if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric metricDiff := route.MaxMetric - r.Metric
tempScore = metricDiff * 10 tempScore = float64(metricDiff) * 10
} }
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
}
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed { if !peerStatus.relayed {
tempScore++ tempScore++
} }
@@ -114,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
tempScore++ tempScore++
} }
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) { if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
chosen = r.ID chosen = r.ID
chosenScore = tempScore chosenScore = tempScore
} }
@@ -123,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
chosen = r.ID chosen = r.ID
chosenScore = tempScore chosenScore = tempScore
} }
if r.ID == currID {
currScore = tempScore
}
} }
if chosen == "" { switch {
case chosen == "":
var peers []string var peers []string
for _, r := range c.routes { for _, r := range c.routes {
peers = append(peers, r.Peer) peers = append(peers, r.Peer)
} }
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers) log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
case chosen != currID:
} else if chosen != currID { if currScore != 0 && currScore < chosenScore+0.1 {
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) return currID
} else {
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
}
} }
return chosen return chosen
@@ -174,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
return fmt.Errorf("get peer state: %v", err) return fmt.Errorf("get peer state: %v", err)
} }
delete(state.Routes, c.network.String()) state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil { if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err) log.Warnf("Failed to update peer state: %v", err)
} }
@@ -246,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
if err != nil { if err != nil {
log.Errorf("Failed to get peer state: %v", err) log.Errorf("Failed to get peer state: %v", err)
} else { } else {
if state.Routes == nil { state.AddRoute(c.network.String())
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
if err := c.statusRecorder.UpdatePeerState(state); err != nil { if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err) log.Warnf("Failed to update peer state: %v", err)
} }

View File

@@ -3,6 +3,7 @@ package routemanager
import ( import (
"net/netip" "net/netip"
"testing" "testing"
"time"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
name string name string
statuses map[string]routerPeerStatus statuses map[string]routerPeerStatus
expectedRouteID string expectedRouteID string
currentRoute *route.Route currentRoute string
existingRoutes map[string]*route.Route existingRoutes map[string]*route.Route
}{ }{
{ {
@@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1", Peer: "peer1",
}, },
}, },
currentRoute: nil, currentRoute: "",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{ {
@@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1", Peer: "peer1",
}, },
}, },
currentRoute: nil, currentRoute: "",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{ {
@@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1", Peer: "peer1",
}, },
}, },
currentRoute: nil, currentRoute: "",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{ {
@@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1", Peer: "peer1",
}, },
}, },
currentRoute: nil, currentRoute: "",
expectedRouteID: "", expectedRouteID: "",
}, },
{ {
@@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2", Peer: "peer2",
}, },
}, },
currentRoute: nil, currentRoute: "",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{ {
@@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2", Peer: "peer2",
}, },
}, },
currentRoute: nil, currentRoute: "",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{ {
@@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2", Peer: "peer2",
}, },
}, },
currentRoute: nil, currentRoute: "",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{
name: "multiple connected peers with different latencies",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 300 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "should ignore routes with latency 0",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 12 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current chosen route doesn't exist anymore",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "routeDoesntExistAnymore",
expectedRouteID: "route2",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
ID: "routeDoesntExistAnymore",
}
if tc.currentRoute != "" {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
// create new clientNetwork // create new clientNetwork
client := &clientNetwork{ client := &clientNetwork{
network: netip.MustParsePrefix("192.168.0.0/24"), network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes, routes: tc.existingRoutes,
chosenRoute: tc.currentRoute, chosenRoute: currentRoute,
} }
chosenRoute := client.getBestRouteFromStatuses(tc.statuses) chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil { if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err) log.Warnf("Failed cleaning up routing: %v", err)
} }
@@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil { if m.serverRouter != nil {
m.serverRouter.cleanUp() m.serverRouter.cleanUp()
} }
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil { if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err) log.Errorf("Error cleaning up routing: %v", err)
} else { } else {
log.Info("Routing cleanup complete") log.Info("Routing cleanup complete")
} }
}
m.ctx = nil m.ctx = nil
} }
@@ -210,10 +219,12 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
} }
func isPrefixSupported(prefix netip.Prefix) bool { func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS { switch runtime.GOOS {
case "linux", "windows", "darwin": case "linux", "windows", "darwin":
return true return true
} }
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported // If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management // we skip this prefix management

View File

@@ -8,6 +8,8 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
@@ -89,19 +91,38 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
} }
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, ok := netip.AddrFromSlice(preferredSrc) addr, err := ipToAddr(preferredSrc, intf)
if !ok { if err != nil {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
} }
return addr.Unmap(), intf, nil return addr.Unmap(), intf, nil
} }
addr, ok := netip.AddrFromSlice(gateway) addr, err := ipToAddr(gateway, intf)
if !ok { if err != nil {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
} }
return addr.Unmap(), intf, nil return addr, intf, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) { func existsInRouteTable(prefix netip.Prefix) (bool, error) {

View File

@@ -8,6 +8,7 @@ import (
"net/netip" "net/netip"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route" "golang.org/x/net/route"
) )
@@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
continue continue
} }
if len(m.Addrs) < 3 {
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
continue
}
addr, ok := toNetIPAddr(m.Addrs[0]) addr, ok := toNetIPAddr(m.Addrs[0])
if !ok { if !ok {
continue continue
} }
mask, ok := toNetIPMASK(m.Addrs[2]) cidr := 32
if mask := m.Addrs[2]; mask != nil {
cidr, ok = toCIDR(mask)
if !ok { if !ok {
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
continue continue
} }
cidr, _ := mask.Size() }
routePrefix := netip.PrefixFrom(addr, cidr) routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() { if routePrefix.IsValid() {
@@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func toNetIPAddr(a route.Addr) (netip.Addr, bool) { func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
switch t := a.(type) { switch t := a.(type) {
case *route.Inet4Addr: case *route.Inet4Addr:
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) return netip.AddrFrom4(t.IP), true
addr := netip.MustParseAddr(ip.String())
return addr, true
default: default:
return netip.Addr{}, false return netip.Addr{}, false
} }
} }
func toNetIPMASK(a route.Addr) (net.IPMask, bool) { func toCIDR(a route.Addr) (int, bool) {
switch t := a.(type) { switch t := a.(type) {
case *route.Inet4Addr: case *route.Inet4Addr:
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
return mask, true cidr, _ := mask.Size()
return cidr, true
default: default:
return nil, false return 0, false
} }
} }

View File

@@ -8,7 +8,9 @@ import (
"net/netip" "net/netip"
"os/exec" "os/exec"
"strings" "strings"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -35,6 +37,10 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string)
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
inet := "-inet" inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
inet = "-inet6" inet = "-inet6"
// Special case for IPv6 split default route, pointing to the wg interface fails // Special case for IPv6 split default route, pointing to the wg interface fails
@@ -44,18 +50,40 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin
} }
} }
args := []string{"-n", action, inet, prefix.String()} args := []string{"-n", action, inet, network}
if nexthop.IsValid() { if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String()) args = append(args, nexthop.Unmap().String())
} else if intf != "" { } else if intf != "" {
args = append(args, "-interface", intf) args = append(args, "-interface", intf)
} }
out, err := exec.Command("route", args...).CombinedOutput() if err := retryRouteCmd(args); err != nil {
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
} }
return nil return nil
} }
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return fmt.Errorf("route cmd retry failed: %w", err)
}
return nil
}

View File

@@ -5,8 +5,10 @@ package routemanager
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"os/exec" "os/exec"
"regexp" "regexp"
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -29,6 +31,42 @@ func init() {
}...) }...)
} }
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := "lo0"
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper() t.Helper()

View File

@@ -4,14 +4,14 @@ package routemanager
import ( import (
"bufio" "bufio"
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strconv"
"strings"
"syscall" "syscall"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -32,19 +32,31 @@ const (
rtTablesPath = "/etc/iproute2/rt_tables" rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting. // ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
) )
var ErrTableIDExists = errors.New("ID exists with different name") var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{} var routeManager = &RouteManager{}
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct { type ruleParams struct {
priority int
fwmark int fwmark int
tableID int tableID int
family int family int
priority int
invert bool invert bool
suppressPrefix int suppressPrefix int
description string description string
@@ -52,10 +64,10 @@ type ruleParams struct {
func getSetupRules() []ruleParams { func getSetupRules() []ruleParams {
return []ruleParams{ return []ruleParams{
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
} }
} }
@@ -69,8 +81,6 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy { if isLegacy {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
@@ -81,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
log.Errorf("Error adding routing table name: %v", err) log.Errorf("Error adding routing table name: %v", err)
} }
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() { defer func() {
if err != nil { if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil { if cleanErr := cleanupRouting(); cleanErr != nil {
@@ -123,11 +140,17 @@ func cleanupRouting() error {
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { if err := removeRule(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
} }
} }
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil() return result.ErrorOrNil()
} }
@@ -144,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf) return genericAddVPNRoute(prefix, intf)
} }
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support // TODO remove this once we have ipv6 support
@@ -336,22 +363,8 @@ func flushRoutes(tableID, family int) error {
} }
func enableIPForwarding() error { func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath) _, err := setSysctl(ipv4ForwardingPath, 1, false)
if err != nil { return err
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
}
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
//nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
} }
// entryExists checks if the specified ID or name already exists in the rt_tables file // entryExists checks if the specified ID or name already exists in the rt_tables file
@@ -429,7 +442,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err) return fmt.Errorf("add routing rule: %w", err)
} }
@@ -446,47 +459,20 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil { if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err) return fmt.Errorf("remove routing rule: %w", err)
} }
return nil return nil
} }
func removeAllRules(params ruleParams) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
for {
if ctx.Err() != nil {
done <- ctx.Err()
return
}
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
done <- nil
return
}
done <- err
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}
// addNextHop adds the gateway and device to the route. // addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() { if addr.IsValid() {
route.Gw = addr.AsSlice() route.Gw = addr.AsSlice()
if intf == "" {
intf = addr.Zone()
}
} }
if intf != "" { if intf != "" {
@@ -506,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int {
} }
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, nil) _, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, cleanupRouting()) assert.NoError(t, cleanupRouting())

View File

@@ -63,7 +63,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
return prefixList, nil return prefixList, nil
} }
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
destinationPrefix := prefix.String() destinationPrefix := prefix.String()
psCmd := "New-NetRoute" psCmd := "New-NetRoute"
@@ -73,10 +73,20 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er
} }
script := fmt.Sprintf( script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
psCmd, addressFamily, destinationPrefix, intf, psCmd, addressFamily, destinationPrefix,
) )
if intfIdx != "" {
script = fmt.Sprintf(
`%s -InterfaceIndex %s`, script, intfIdx,
)
} else {
script = fmt.Sprintf(
`%s -InterfaceAlias "%s"`, script, intf,
)
}
if nexthop.IsValid() { if nexthop.IsValid() {
script = fmt.Sprintf( script = fmt.Sprintf(
`%s -NextHop "%s"`, script, nexthop, `%s -NextHop "%s"`, script, nexthop,
@@ -84,7 +94,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er
} }
out, err := exec.Command("powershell", "-Command", script).CombinedOutput() out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
log.Tracef("PowerShell add route: %s", string(out)) log.Tracef("PowerShell %s: %s", script, string(out))
if err != nil { if err != nil {
return fmt.Errorf("PowerShell add route: %w", err) return fmt.Errorf("PowerShell add route: %w", err)
@@ -98,7 +108,7 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
out, err := exec.Command("route", args...).CombinedOutput() out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s output: %s", strings.Join(args, " "), out) log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil { if err != nil {
return fmt.Errorf("route add: %w", err) return fmt.Errorf("route add: %w", err)
} }
@@ -107,9 +117,15 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
var intfIdx string
if nexthop.Zone() != "" {
intfIdx = nexthop.Zone()
nexthop.WithZone("")
}
// Powershell doesn't support adding routes without an interface but allows to add interface by name // Powershell doesn't support adding routes without an interface but allows to add interface by name
if intf != "" { if intf != "" || intfIdx != "" {
return addRoutePowershell(prefix, nexthop, intf) return addRoutePowershell(prefix, nexthop, intf, intfIdx)
} }
return addRouteCmd(prefix, nexthop, intf) return addRouteCmd(prefix, nexthop, intf)
} }
@@ -117,11 +133,12 @@ func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"delete", prefix.String()} args := []string{"delete", prefix.String()}
if nexthop.IsValid() { if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String()) args = append(args, nexthop.Unmap().String())
} }
out, err := exec.Command("route", args...).CombinedOutput() out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s output: %s", strings.Join(args, " "), out) log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil { if err != nil {
return fmt.Errorf("remove route: %w", err) return fmt.Errorf("remove route: %w", err)

View File

@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
} }
// Set the fwmark on the socket. // Set the fwmark on the socket.
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) err = nbnet.SetSocketOpt(fd)
if err != nil { if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err) return nil, fmt.Errorf("setting fwmark failed: %w", err)
} }

View File

@@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
BytesRx: peerState.BytesRx, BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx, BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled, RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.Routes), Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency), Latency: durationpb.New(peerState.Latency),
} }
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)

View File

@@ -25,8 +25,6 @@ func Detect(ctx context.Context) string {
detectDigitalOcean, detectDigitalOcean,
detectGCP, detectGCP,
detectOracle, detectOracle,
detectIBMCloud,
detectSoftlayer,
detectVultr, detectVultr,
} }

View File

@@ -6,7 +6,7 @@ import (
) )
func detectGCP(ctx context.Context) string { func detectGCP(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil) req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil)
if err != nil { if err != nil {
return "" return ""
} }

View File

@@ -1,54 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectIBMCloud(ctx context.Context) string {
v1ResultChan := make(chan bool, 1)
v2ResultChan := make(chan bool, 1)
go func() {
v1ResultChan <- detectIBMSecure(ctx)
}()
go func() {
v2ResultChan <- detectIBM(ctx)
}()
v1Result, v2Result := <-v1ResultChan, <-v2ResultChan
if v1Result || v2Result {
return "IBM Cloud"
}
return ""
}
func detectIBMSecure(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
if err != nil {
return false
}
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
func detectIBM(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
if err != nil {
return false
}
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}

View File

@@ -1,25 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectSoftlayer(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil)
if err != nil {
return ""
}
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
// Since SoftLayer was acquired by IBM, we should return "IBM Cloud"
return "IBM Cloud"
}
return ""
}

2
go.mod
View File

@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.43 github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible

4
go.sum
View File

@@ -383,8 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA= github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=

View File

@@ -10,8 +10,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type wgKernelConfigurer struct { type wgKernelConfigurer struct {
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
if err != nil { if err != nil {
return err return err
} }
fwmark := nbnet.NetbirdFwmark fwmark := getFwmark()
config := wgtypes.Config{ config := wgtypes.Config{
PrivateKey: &key, PrivateKey: &key,
ReplacePeers: true, ReplacePeers: true,

View File

@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
} }
func getFwmark() int { func getFwmark() int {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark return nbnet.NetbirdFwmark
} }
return 0 return 0

View File

@@ -58,6 +58,7 @@ services:
command: [ command: [
"--port", "443", "--port", "443",
"--log-file", "console", "--log-file", "console",
"--log-level", "info",
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"

View File

@@ -2,7 +2,7 @@ version: "3"
services: services:
#UI dashboard #UI dashboard
dashboard: dashboard:
image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
restart: unless-stopped restart: unless-stopped
#ports: #ports:
# - 80:80 # - 80:80

View File

@@ -251,7 +251,7 @@ var (
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
defer cancel() defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg) httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }

View File

@@ -278,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
return routes return routes
} }
// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership // filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route var filteredRoutes []*route.Route
for _, r := range routes { for _, r := range routes {
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
} }
if user.Id != account.CreatedBy { if user.Role != UserRoleOwner {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
} }
for _, otherUser := range account.Users { for _, otherUser := range account.Users {
@@ -1473,7 +1473,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
// if domain already has a primary account, add regular user // if domain already has a primary account, add regular user
if domainAcc != nil { if domainAcc != nil {
account = domainAcc account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId) account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1849,6 +1849,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
} }
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
log.Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(accountID) updatedAccount, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
log.Errorf("failed to get account %s: %v", accountID, err) log.Errorf("failed to get account %s: %v", accountID, err)
@@ -1864,6 +1865,7 @@ func addAllGroup(account *Account) error {
ID: xid.New().String(), ID: xid.New().String(),
Name: "All", Name: "All",
Issued: nbgroup.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
AccountID: account.Id,
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID) allGroup.Peers = append(allGroup.Peers, peer.ID)
@@ -1907,7 +1909,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
routes := make(map[string]*route.Route) routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{} setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup) nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID) users[userID] = NewOwnerUser(userID, accountID)
dnsSettings := DNSSettings{ dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0), DisabledManagementGroups: make([]string, 0),
} }

View File

@@ -11,133 +11,134 @@ type Code struct {
Code string Code string
} }
// Existing consts must not be changed, as this will break the compatibility with the existing data
const ( const (
// PeerAddedByUser indicates that a user added a new peer to the system // PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = iota PeerAddedByUser Activity = 0
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key // PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
PeerAddedWithSetupKey PeerAddedWithSetupKey Activity = 1
// UserJoined indicates that a new user joined the account // UserJoined indicates that a new user joined the account
UserJoined UserJoined Activity = 2
// UserInvited indicates that a new user was invited to join the account // UserInvited indicates that a new user was invited to join the account
UserInvited UserInvited Activity = 3
// AccountCreated indicates that a new account has been created // AccountCreated indicates that a new account has been created
AccountCreated AccountCreated Activity = 4
// PeerRemovedByUser indicates that a user removed a peer from the system // PeerRemovedByUser indicates that a user removed a peer from the system
PeerRemovedByUser PeerRemovedByUser Activity = 5
// RuleAdded indicates that a user added a new rule // RuleAdded indicates that a user added a new rule
RuleAdded RuleAdded Activity = 6
// RuleUpdated indicates that a user updated a rule // RuleUpdated indicates that a user updated a rule
RuleUpdated RuleUpdated Activity = 7
// RuleRemoved indicates that a user removed a rule // RuleRemoved indicates that a user removed a rule
RuleRemoved RuleRemoved Activity = 8
// PolicyAdded indicates that a user added a new policy // PolicyAdded indicates that a user added a new policy
PolicyAdded PolicyAdded Activity = 9
// PolicyUpdated indicates that a user updated a policy // PolicyUpdated indicates that a user updated a policy
PolicyUpdated PolicyUpdated Activity = 10
// PolicyRemoved indicates that a user removed a policy // PolicyRemoved indicates that a user removed a policy
PolicyRemoved PolicyRemoved Activity = 11
// SetupKeyCreated indicates that a user created a new setup key // SetupKeyCreated indicates that a user created a new setup key
SetupKeyCreated SetupKeyCreated Activity = 12
// SetupKeyUpdated indicates that a user updated a setup key // SetupKeyUpdated indicates that a user updated a setup key
SetupKeyUpdated SetupKeyUpdated Activity = 13
// SetupKeyRevoked indicates that a user revoked a setup key // SetupKeyRevoked indicates that a user revoked a setup key
SetupKeyRevoked SetupKeyRevoked Activity = 14
// SetupKeyOverused indicates that setup key usage exhausted // SetupKeyOverused indicates that setup key usage exhausted
SetupKeyOverused SetupKeyOverused Activity = 15
// GroupCreated indicates that a user created a group // GroupCreated indicates that a user created a group
GroupCreated GroupCreated Activity = 16
// GroupUpdated indicates that a user updated a group // GroupUpdated indicates that a user updated a group
GroupUpdated GroupUpdated Activity = 17
// GroupAddedToPeer indicates that a user added group to a peer // GroupAddedToPeer indicates that a user added group to a peer
GroupAddedToPeer GroupAddedToPeer Activity = 18
// GroupRemovedFromPeer indicates that a user removed peer group // GroupRemovedFromPeer indicates that a user removed peer group
GroupRemovedFromPeer GroupRemovedFromPeer Activity = 19
// GroupAddedToUser indicates that a user added group to a user // GroupAddedToUser indicates that a user added group to a user
GroupAddedToUser GroupAddedToUser Activity = 20
// GroupRemovedFromUser indicates that a user removed a group from a user // GroupRemovedFromUser indicates that a user removed a group from a user
GroupRemovedFromUser GroupRemovedFromUser Activity = 21
// UserRoleUpdated indicates that a user changed the role of a user // UserRoleUpdated indicates that a user changed the role of a user
UserRoleUpdated UserRoleUpdated Activity = 22
// GroupAddedToSetupKey indicates that a user added group to a setup key // GroupAddedToSetupKey indicates that a user added group to a setup key
GroupAddedToSetupKey GroupAddedToSetupKey Activity = 23
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key // GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
GroupRemovedFromSetupKey GroupRemovedFromSetupKey Activity = 24
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups // GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
GroupAddedToDisabledManagementGroups GroupAddedToDisabledManagementGroups Activity = 25
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups // GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
GroupRemovedFromDisabledManagementGroups GroupRemovedFromDisabledManagementGroups Activity = 26
// RouteCreated indicates that a user created a route // RouteCreated indicates that a user created a route
RouteCreated RouteCreated Activity = 27
// RouteRemoved indicates that a user deleted a route // RouteRemoved indicates that a user deleted a route
RouteRemoved RouteRemoved Activity = 28
// RouteUpdated indicates that a user updated a route // RouteUpdated indicates that a user updated a route
RouteUpdated RouteUpdated Activity = 29
// PeerSSHEnabled indicates that a user enabled SSH server on a peer // PeerSSHEnabled indicates that a user enabled SSH server on a peer
PeerSSHEnabled PeerSSHEnabled Activity = 30
// PeerSSHDisabled indicates that a user disabled SSH server on a peer // PeerSSHDisabled indicates that a user disabled SSH server on a peer
PeerSSHDisabled PeerSSHDisabled Activity = 31
// PeerRenamed indicates that a user renamed a peer // PeerRenamed indicates that a user renamed a peer
PeerRenamed PeerRenamed Activity = 32
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer // PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
PeerLoginExpirationEnabled PeerLoginExpirationEnabled Activity = 33
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer // PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
PeerLoginExpirationDisabled PeerLoginExpirationDisabled Activity = 34
// NameserverGroupCreated indicates that a user created a nameservers group // NameserverGroupCreated indicates that a user created a nameservers group
NameserverGroupCreated NameserverGroupCreated Activity = 35
// NameserverGroupDeleted indicates that a user deleted a nameservers group // NameserverGroupDeleted indicates that a user deleted a nameservers group
NameserverGroupDeleted NameserverGroupDeleted Activity = 36
// NameserverGroupUpdated indicates that a user updated a nameservers group // NameserverGroupUpdated indicates that a user updated a nameservers group
NameserverGroupUpdated NameserverGroupUpdated Activity = 37
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account // AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
AccountPeerLoginExpirationEnabled AccountPeerLoginExpirationEnabled Activity = 38
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account // AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
AccountPeerLoginExpirationDisabled AccountPeerLoginExpirationDisabled Activity = 39
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account // AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
AccountPeerLoginExpirationDurationUpdated AccountPeerLoginExpirationDurationUpdated Activity = 40
// PersonalAccessTokenCreated indicates that a user created a personal access token // PersonalAccessTokenCreated indicates that a user created a personal access token
PersonalAccessTokenCreated PersonalAccessTokenCreated Activity = 41
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token // PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted PersonalAccessTokenDeleted Activity = 42
// ServiceUserCreated indicates that a user created a service user // ServiceUserCreated indicates that a user created a service user
ServiceUserCreated ServiceUserCreated Activity = 43
// ServiceUserDeleted indicates that a user deleted a service user // ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted ServiceUserDeleted Activity = 44
// UserBlocked indicates that a user blocked another user // UserBlocked indicates that a user blocked another user
UserBlocked UserBlocked Activity = 45
// UserUnblocked indicates that a user unblocked another user // UserUnblocked indicates that a user unblocked another user
UserUnblocked UserUnblocked Activity = 46
// UserDeleted indicates that a user deleted another user // UserDeleted indicates that a user deleted another user
UserDeleted UserDeleted Activity = 47
// GroupDeleted indicates that a user deleted group // GroupDeleted indicates that a user deleted group
GroupDeleted GroupDeleted Activity = 48
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer UserLoggedInPeer Activity = 49
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected // PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired PeerLoginExpired Activity = 50
// DashboardLogin indicates that the user logged in to the dashboard // DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin DashboardLogin Activity = 51
// IntegrationCreated indicates that the user created an integration // IntegrationCreated indicates that the user created an integration
IntegrationCreated IntegrationCreated Activity = 52
// IntegrationUpdated indicates that the user updated an integration // IntegrationUpdated indicates that the user updated an integration
IntegrationUpdated IntegrationUpdated Activity = 53
// IntegrationDeleted indicates that the user deleted an integration // IntegrationDeleted indicates that the user deleted an integration
IntegrationDeleted IntegrationDeleted Activity = 54
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account // AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
AccountPeerApprovalEnabled AccountPeerApprovalEnabled Activity = 55
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account // AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
AccountPeerApprovalDisabled AccountPeerApprovalDisabled Activity = 56
// PeerApproved indicates that the peer has been approved // PeerApproved indicates that the peer has been approved
PeerApproved PeerApproved Activity = 57
// PeerApprovalRevoked indicates that the peer approval has been revoked // PeerApprovalRevoked indicates that the peer approval has been revoked
PeerApprovalRevoked PeerApprovalRevoked Activity = 58
// TransferredOwnerRole indicates that the user transferred the owner role of the account // TransferredOwnerRole indicates that the user transferred the owner role of the account
TransferredOwnerRole TransferredOwnerRole Activity = 59
// PostureCheckCreated indicates that the user created a posture check // PostureCheckCreated indicates that the user created a posture check
PostureCheckCreated PostureCheckCreated Activity = 60
// PostureCheckUpdated indicates that the user updated a posture check // PostureCheckUpdated indicates that the user updated a posture check
PostureCheckUpdated PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check // PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted PostureCheckDeleted Activity = 62
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{

View File

@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
func TestAccounts_AccountsHandler(t *testing.T) { func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user", "account_id")
sr := func(v string) *string { return &v } sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v } br := func(v bool) *bool { return &v }

View File

@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
Id: testDNSSettingsAccountID, Id: testDNSSettingsAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Users: map[string]*server.User{ Users: map[string]*server.User{
testDNSSettingsUserID: server.NewAdminUser("test_user"), testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
}, },
DNSSettings: baseExistingDNSSettings, DNSSettings: baseExistingDNSSettings,
} }

View File

@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
}, },
} }
accountID := "test_account" accountID := "test_account"
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user", "account_id")
events := generateEvents(accountID, adminUser.Id) events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...) handler := initEventsTestData(accountID, adminUser, events...)

View File

@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{ return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user", "account_id")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Users: map[string]*server.User{ Users: map[string]*server.User{

View File

@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group", Name: "Group",
} }
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser, group) p := initGroupTestData(adminUser, group)
for _, tc := range tt { for _, tc := range tt {
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
}, },
} }
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser) p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
}, },
} }
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser) p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {

View File

@@ -12,6 +12,7 @@ import (
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
@@ -38,7 +39,7 @@ type emptyObject struct {
} }
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor( claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg, AuthCfg: authCfg,
} }
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil { if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err) return nil, fmt.Errorf("register integrations endpoints: %w", err)
} }

View File

@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID, Id: testNSGroupAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"), "test_user": server.NewAdminUser("test_user", "account_id"),
}, },
} }

View File

@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user", "account_id")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",

View File

@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
return nil return nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user", "account_id")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",

View File

@@ -62,7 +62,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return accountPostureChecks, nil return accountPostureChecks, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user", "account_id")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Users: map[string]*server.User{ Users: map[string]*server.User{

View File

@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
}, },
}, },
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"), "test_user": server.NewAdminUser("test_user", "account_id"),
}, },
} }

View File

@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey() defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user", "account_id")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true) server.SetupKeyUnlimitedUsage, true)

View File

@@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error {
return nil return nil
} }
// parseOktaUserToUserData parse okta user to UserData. // parseOktaUser parse okta user to UserData.
func parseOktaUser(user *okta.User) (*UserData, error) { func parseOktaUser(user *okta.User) (*UserData, error) {
var oktaUser struct { var oktaUser struct {
Email string `json:"email"` Email string `json:"email"`

View File

@@ -706,7 +706,7 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
return nil return nil
} }
// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface // UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
if am.UpdateIntegratedValidatorGroupsFunc != nil { if am.UpdateIntegratedValidatorGroupsFunc != nil {
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)

View File

@@ -551,8 +551,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
} }
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if requiresApproval { if peerNotValid {
emptyMap := &NetworkMap{ emptyMap := &NetworkMap{
Network: account.Network.Copy(), Network: account.Network.Copy(),
} }
@@ -563,11 +563,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
am.updateAccountPeers(account) am.updateAccountPeers(account)
} }
approvedPeersMap, err := am.GetValidatedPeers(account) validPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
} }
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.

View File

@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
case <-ticker.C: case <-ticker.C:
select { select {
case <-cancel: case <-cancel:
log.Debugf("scheduled job %s was canceled, stop timer", ID) log.Tracef("scheduled job %s was canceled, stop timer", ID)
ticker.Stop() ticker.Stop()
return return
default: default:
log.Debugf("time to do a scheduled job %s", ID) log.Tracef("time to do a scheduled job %s", ID)
} }
runIn, reschedule := job() runIn, reschedule := job()
if !reschedule { if !reschedule {
wm.mu.Lock() wm.mu.Lock()
defer wm.mu.Unlock() defer wm.mu.Unlock()
delete(wm.jobs, ID) delete(wm.jobs, ID)
log.Debugf("job %s is not scheduled to run again", ID) log.Tracef("job %s is not scheduled to run again", ID)
ticker.Stop() ticker.Stop()
return return
} }
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
ticker.Reset(runIn) ticker.Reset(runIn)
} }
case <-cancel: case <-cancel:
log.Debugf("job %s was canceled, stopping timer", ID) log.Tracef("job %s was canceled, stopping timer", ID)
ticker.Stop() ticker.Stop()
return return
} }

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock return unlock
} }
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
// Get the reflect.Value of the records slice
v := reflect.ValueOf(records)
if v.Kind() != reflect.Slice {
return fmt.Errorf("provided input is not a slice")
}
// Insert records in batches
for i := 0; i < v.Len(); i += batchSize {
end := i + batchSize
if end > v.Len() {
end = v.Len()
}
// Use reflect.Slice to get a slice of the records for the current batch
batch := v.Slice(i, end).Interface()
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
return err
}
}
return nil
}
func (s *SqliteStore) SaveAccount(account *Account) error { func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now() start := time.Now()
for _, key := range account.SetupKeys { // operate over a fresh copy as we will modify its fields
account.SetupKeysG = append(account.SetupKeysG, *key) accCopy := account.Copy()
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
for _, key := range accCopy.SetupKeys {
//we need an explicit reference to the account for gorm
key.AccountID = accCopy.Id
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
} }
for id, peer := range account.Peers { accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
for id, peer := range accCopy.Peers {
peer.ID = id peer.ID = id
account.PeersG = append(account.PeersG, *peer) //we need an explicit reference to the account for gorm
peer.AccountID = accCopy.Id
accCopy.PeersG = append(accCopy.PeersG, *peer)
} }
for id, user := range account.Users { accCopy.UsersG = make([]User, 0, len(accCopy.Users))
for id, user := range accCopy.Users {
user.Id = id user.Id = id
//we need an explicit reference to the account for gorm
user.AccountID = accCopy.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs { for id, pat := range user.PATs {
pat.ID = id pat.ID = id
user.PATsG = append(user.PATsG, *pat) user.PATsG = append(user.PATsG, *pat)
} }
account.UsersG = append(account.UsersG, *user) accCopy.UsersG = append(accCopy.UsersG, *user)
} }
for id, group := range account.Groups { accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
for id, group := range accCopy.Groups {
group.ID = id group.ID = id
account.GroupsG = append(account.GroupsG, *group) //we need an explicit reference to the account for gorm
group.AccountID = accCopy.Id
accCopy.GroupsG = append(accCopy.GroupsG, *group)
} }
for id, route := range account.Routes { accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
for id, route := range accCopy.Routes {
route.ID = id route.ID = id
account.RoutesG = append(account.RoutesG, *route) //we need an explicit reference to the account for gorm
route.AccountID = accCopy.Id
accCopy.RoutesG = append(accCopy.RoutesG, *route)
} }
for id, ns := range account.NameServerGroups { accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
for id, ns := range accCopy.NameServerGroups {
ns.ID = id ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns) //we need an explicit reference to the account for gorm
ns.AccountID = accCopy.Id
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
} }
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
result = tx.Select(clause.Associations).Delete(account) result = tx.Select(clause.Associations).Delete(accCopy)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
result = tx. result = tx.
Session(&gorm.Session{FullSaveAssociations: true}). Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account) Clauses(clause.OnConflict{UpdateAll: true}).
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
Create(accCopy)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
return nil
const batchSize = 500
err := batchInsert(accCopy.PeersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.UsersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.GroupsG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.RoutesG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
if err != nil {
return err
}
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
}) })
took := time.Since(start) took := time.Since(start)
if s.metrics != nil { if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took) s.metrics.StoreMetrics().CountPersistenceDuration(took)
} }
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds()) log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
return err return err
} }
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
func (s *SqliteStore) DeleteAccount(account *Account) error { func (s *SqliteStore) DeleteAccount(account *Account) error {
start := time.Now() start := time.Now()
account.UsersG = make([]User, 0, len(account.Users))
for id, user := range account.Users {
user.Id = id
//we need an explicit reference to an account as it is missing for some reason
user.AccountID = account.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil { if result.Error != nil {

View File

@@ -2,7 +2,12 @@ package server
import ( import (
"fmt" "fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
route2 "github.com/netbirdio/netbird/route"
"math/rand"
"net" "net"
"net/netip"
"path/filepath" "path/filepath"
"runtime" "runtime"
"testing" "testing"
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
} }
} }
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func TestSqlite_SaveAccount(t *testing.T) { func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
admin := account.Users["testuser"]
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
HashedToken: "hashed token",
}}
err := store.SaveAccount(account) err := store.SaveAccount(account)
require.NoError(t, err) require.NoError(t, err)
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
store := newSqliteStore(t) store := newSqliteStore(t)
testUserID := "testuser" testUserID := "testuser"
user := NewAdminUser(testUserID) user := NewAdminUser(testUserID, "account_id")
user.PATs = map[string]*PersonalAccessToken{"testtoken": { user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken", ID: "testtoken",
Name: "test token", Name: "test token",
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account) return store.SaveAccount(account)
} }
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}

View File

@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
} }
// NewUser creates a new user // NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
accountID string) *User {
return &User{ return &User{
Id: id, Id: ID,
AccountID: accountID,
Role: role, Role: role,
IsServiceUser: isServiceUser, IsServiceUser: isServiceUser,
NonDeletable: nonDeletable, NonDeletable: nonDeletable,
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
} }
// NewRegularUser creates a new user with role UserRoleUser // NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User { func NewRegularUser(ID, accountID string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
accountID)
} }
// NewAdminUser creates a new user with role UserRoleAdmin // NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User { func NewAdminUser(ID, accountID string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
accountID)
} }
// NewOwnerUser creates a new user with role UserRoleOwner // NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User { func NewOwnerUser(ID, accountID string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
accountID)
} }
// createServiceUser creates a new service user under the given account. // createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
} }
newUserID := uuid.New().String() newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
log.Debugf("New User: %v", newUser) log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser account.Users[newUserID] = newUser

View File

@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1") account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
account.Users["normal_user2"] = NewRegularUser("normal_user2") account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
err := store.SaveAccount(account) err := store.SaveAccount(account)
if err != nil { if err != nil {
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID) delete(account.Users, mockUserID)
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
func TestUser_IsAdmin(t *testing.T) { func TestUser_IsAdmin(t *testing.T) {
user := NewAdminUser(mockUserID) user := NewAdminUser(mockUserID, mockAccountID)
assert.True(t, user.HasAdminPower()) assert.True(t, user.HasAdminPower())
user = NewRegularUser(mockUserID) user = NewRegularUser(mockUserID, mockAccountID)
assert.False(t, user.HasAdminPower()) assert.False(t, user.HasAdminPower())
} }
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
} }
// create other users // create other users
account.Users[regularUserID] = NewRegularUser(regularUserID) account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
account.Users[adminUserID] = NewAdminUser(adminUserID) account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(account) err = manager.Store.SaveAccount(account)
if err != nil { if err != nil {

View File

@@ -3,6 +3,8 @@ package grpc
import ( import (
"context" "context"
"net" "net"
"os/user"
"runtime"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -12,6 +14,20 @@ import (
func WithCustomDialer() grpc.DialOption { func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
log.Fatalf("failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
log.Errorf("Failed to dial: %s", err) log.Errorf("Failed to dial: %s", err)

View File

@@ -1,10 +1,7 @@
package net package net
import ( import (
"fmt"
"net" "net"
log "github.com/sirupsen/logrus"
) )
// Dialer extends the standard net.Dialer with the ability to execute hooks before // Dialer extends the standard net.Dialer with the ability to execute hooks before
@@ -22,43 +19,3 @@ func NewDialer() *Dialer {
return dialer return dialer
} }
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type")
}
return tcpConn, nil
}

View File

@@ -49,6 +49,10 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection // DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
var resolver *net.Resolver var resolver *net.Resolver
if d.Resolver != nil { if d.Resolver != nil {
resolver = d.Resolver resolver = d.Resolver
@@ -56,7 +60,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
connID := GenerateConnID() connID := GenerateConnID()
if dialerDialHooks != nil { if dialerDialHooks != nil {
if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
log.Errorf("Failed to call dialer hooks: %v", err) log.Errorf("Failed to call dialer hooks: %v", err)
} }
} }
@@ -97,7 +101,7 @@ func (c *Conn) Close() error {
return err return err
} }
func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
host, _, err := net.SplitHostPort(address) host, _, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return fmt.Errorf("split host and port: %w", err) return fmt.Errorf("split host and port: %w", err)
@@ -121,3 +125,51 @@ func calliDialerHooks(ctx context.Context, connID ConnectionID, address string,
return result.ErrorOrNil() return result.ErrorOrNil()
} }
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

15
util/net/dialer_mobile.go Normal file
View File

@@ -0,0 +1,15 @@
//go:build android || ios
package net
import (
"net"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
}

View File

@@ -8,6 +8,7 @@ import (
"net" "net"
"sync" "sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -52,6 +53,10 @@ func RemoveListenerHooks() {
// ListenPacket listens on the network address and returns a PacketConn // ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks. // which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address) pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil { if err != nil {
return nil, fmt.Errorf("listen packet: %w", err) return nil, fmt.Errorf("listen packet: %w", err)
@@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
// ListenUDP listens on the network address and returns a transport.UDPConn // ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks. // which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err) return nil, fmt.Errorf("listen UDP: %w", err)
@@ -156,7 +165,7 @@ func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
if err := packetConn.Close(); err != nil { if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err) log.Errorf("Failed to close connection: %v", err)
} }
return nil, fmt.Errorf("expected UDPConn, got different type") return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
} }
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil

View File

@@ -1,10 +1,16 @@
package net package net
import "github.com/google/uuid" import (
"os"
"github.com/google/uuid"
)
const ( const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard // NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00 NetbirdFwmark = 0x1BD00
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
) )
// ConnectionID provides a globally unique identifier for network connections. // ConnectionID provides a globally unique identifier for network connections.
@@ -15,3 +21,7 @@ type ConnectionID string
func GenerateConnID() ConnectionID { func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString()) return ConnectionID(uuid.NewString())
} }
func CustomRoutingDisabled() bool {
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
var setErr error var setErr error
err := conn.Control(func(fd uintptr) { err := conn.Control(func(fd uintptr) {
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) setErr = SetSocketOpt(int(fd))
}) })
if err != nil { if err != nil {
return fmt.Errorf("control: %w", err) return fmt.Errorf("control: %w", err)
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil return nil
} }
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
return nil
}
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
}