Compare commits

...

63 Commits

Author SHA1 Message Date
bcmmbaga
bf4767211a Merge branch 'refs/heads/feature/optimize_sqlite_save' into deploy/posture-check-sqlite 2024-04-18 11:05:06 +03:00
Misha Bragin
515ce9e3af Update management/server/sqlite_store.go 2024-04-17 20:55:32 +02:00
Misha Bragin
89383b7f01 Update management/server/sqlite_store.go 2024-04-17 20:55:01 +02:00
Misha Bragin
db34162733 Update management/server/sqlite_store.go 2024-04-17 20:54:14 +02:00
Misha Bragin
bd761e2177 Update management/server/sqlite_store.go 2024-04-17 20:53:32 +02:00
Misha Bragin
4e1b95a4c6 Update management/server/sqlite_store.go 2024-04-17 20:53:24 +02:00
Misha Bragin
05993af7bf Update management/server/sqlite_store.go 2024-04-17 20:53:11 +02:00
braginini
9d1cb00570 Fix setup keys test 2024-04-17 20:27:55 +02:00
braginini
543731df45 Fix setup keys test 2024-04-17 19:58:24 +02:00
braginini
e6628ec231 Fix setup keys 2024-04-17 19:48:09 +02:00
braginini
41d4dd2aff reduce log level of scheduler to trace 2024-04-17 19:34:59 +02:00
braginini
30bed57711 Fix account deletion 2024-04-17 19:12:53 +02:00
braginini
6960b68322 Add pats to test save account 2024-04-17 19:07:17 +02:00
braginini
3b3aa18148 Store setup keys and ns groups in a batch 2024-04-17 18:32:13 +02:00
braginini
93045f3e3a Fix rand lint issue 2024-04-17 18:07:02 +02:00
braginini
fd3c1dea8e Add save large account test 2024-04-17 18:02:10 +02:00
braginini
48aff7a26e Fix test compilation errors 2024-04-17 17:39:28 +02:00
braginini
83dfe8e3a3 Fix test compilation errors 2024-04-17 17:27:23 +02:00
braginini
38e10af2d9 Add accountID reference 2024-04-17 17:16:56 +02:00
braginini
99854a126a Add comments 2024-04-17 17:08:01 +02:00
braginini
a75f982fcd Copy account when storing to avoid reference issues 2024-04-17 17:03:21 +02:00
bcmmbaga
7745ed7eb0 Merge branch 'refs/heads/main' into add-process-posture-check 2024-04-17 16:37:29 +03:00
braginini
e7a6483912 Optimize all other objects storing in SQLite 2024-04-17 12:35:41 +02:00
braginini
30ede299b8 Optimize peer storing in SQLite 2024-04-17 11:50:33 +02:00
Viktor Liu
e3b76448f3 Fix ICE endpoint remote port in status command (#1851) 2024-04-16 14:01:59 +02:00
bcmmbaga
6bfd1b2886 fix merge conflicts 2024-04-15 16:18:41 +03:00
bcmmbaga
8aa32a2da5 Merge branch 'refs/heads/main' into add-process-posture-check
# Conflicts:
#	management/server/peer.go
2024-04-15 16:14:21 +03:00
Bethuel Mmbaga
c6ab215d9d Extend management to sync meta and posture checks with peer (#1727)
* Add method to retrieve peer's applied posture checks

* Add posture checks in server response and update proto messages

* Refactor

* Extends peer metadata synchronization through SyncRequest and propagate posture changes on syncResponse

* Remove account lock

* Pass system info on sync

* Fix tests

* Refactor

* resolve merge

* Evaluate process check on client (#1749)

* implement  server and client sync peer meta alongside mocks

* wip: add check file and process

* Add files to peer metadata for process check

* wip: update peer meta on first sync

* Add files to peer's metadata

* Evaluate process check using files from peer metadata

* Fix panic and append windows path to files

* Fix check network address and files equality

* Evaluate active process on darwin

* Evaluate active process on linux

* Skip processing processes if no paths are set

* Return network map on peer meta-sync and update account peer's

* Update client network map on meta sync

* Get system info with applied checks

* Add windows package

* Remove a network map from sync meta-response

* Update checks proto message

* Keep client checks state and sync meta on checks change

* Evaluate a running process

* skip build for android and ios

* skip check file and process for android and ios

* bump gopsutil version

* fix tests

* move process check to separate os file

* refactor

* evaluate info with checks on receiving management events

* skip meta-update for an old client with no meta-sync support

* Check if peer meta is empty without reflection
2024-04-15 16:00:57 +03:00
Viktor Liu
e0de86d6c9 Use fixed activity codes (#1846)
* Add duplicate constants check
2024-04-15 14:15:46 +02:00
Zoltan Papp
5204d07811 Pass integrated validator for API (#1814)
Pass integrated validator for API handler
2024-04-15 12:08:38 +02:00
Viktor Liu
5ea24ba56e Add sysctl opts to prevent reverse path filtering from dropping fwmark packets (#1839) 2024-04-12 17:53:07 +02:00
Viktor Liu
d30cf8706a Allow disabling custom routing (#1840) 2024-04-12 16:53:11 +02:00
Viktor Liu
15a2feb723 Use fixed preference for rules (#1836) 2024-04-12 16:07:03 +02:00
Viktor Liu
91b2f9fc51 Use route active store (#1834) 2024-04-12 15:22:40 +02:00
Carlos Hernandez
76702c8a09 Add safe read/write to route map (#1760) 2024-04-11 22:12:23 +02:00
Viktor Liu
061f673a4f Don't use the custom dialer as non-root (#1823) 2024-04-11 15:29:03 +02:00
Zoltan Papp
9505805313 Rename variable (#1829) 2024-04-11 14:08:03 +02:00
Maycon Santos
704c67dec8 Allow owners that did not create the account to delete it (#1825)
Sometimes the Owner role will be passed to new users, and they need to be able to delete the account
2024-04-11 10:02:51 +02:00
bcmmbaga
36582d13aa Merge branch 'refs/heads/main' into add-process-posture-check 2024-04-10 17:58:46 +03:00
pascal-fischer
3ed2f08f3c Add latency based routing (#1732)
Now that we have the latency between peers available we can use this data to consider when choosing the best route. This way the route with the routing peer with the lower latency will be preferred over others with the same target network.
2024-04-09 21:20:02 +02:00
Maycon Santos
4c83408f27 Add log-level to the management's docker service command (#1820) 2024-04-09 21:00:43 +02:00
Viktor Liu
90bd39c740 Log panics (#1818) 2024-04-09 20:27:27 +02:00
Maycon Santos
dd0cf41147 Auto restart Windows agent daemon service (#1819)
This enables auto restart of the windows agent daemon service on event of failure
2024-04-09 20:10:59 +02:00
pascal-fischer
22b2caffc6 Remove dns based cloud detection (#1812)
* remove dns based cloud checks

* remove dns based cloud checks
2024-04-09 19:01:31 +02:00
Viktor Liu
c1f66d1354 Retry macOS route command (#1817) 2024-04-09 15:27:19 +02:00
Viktor Liu
ac0fe6025b Fix routing issues with MacOS (#1815)
* Handle zones properly

* Use host routes for single IPs 

* Add GOOS and GOARCH to startup log

* Log powershell command
2024-04-09 13:25:14 +02:00
verytrap
c28657710a Fix function names in comments (#1816)
Signed-off-by: verytrap <wangqiuyue@outlook.com>
2024-04-09 13:18:38 +02:00
Maycon Santos
3875c29f6b Revert "Rollback new routing functionality (#1805)" (#1813)
This reverts commit 9f32ccd453.
2024-04-08 18:56:52 +02:00
Viktor Liu
9f32ccd453 Rollback new routing functionality (#1805) 2024-04-05 20:38:49 +02:00
trax
1d1d057e7d Change the dashboard image pull from wiretrustee to netbirdio (#1804) 2024-04-05 13:51:28 +02:00
bcmmbaga
2727680123 Merge branch 'main' into add-process-posture-check 2024-03-21 21:30:40 +03:00
bcmmbaga
9dcaa51b68 Merge branch 'main' into add-process-posture-check 2024-03-18 18:41:38 +03:00
Bethuel Mmbaga
180f5a122e Refactor posture check validations (#1705)
* Add posture checks validation

* Refactor code to incorporate posture checks validation directly into management.

* Add posture checks validation for geolocation, OS version, network, process, and NB-version

* Fix tests
2024-03-14 20:16:50 +00:00
bcmmbaga
90ab2f7c89 Fix linters 2024-03-14 16:06:50 +03:00
bcmmbaga
4ab993c933 Fix tests 2024-03-14 15:52:15 +03:00
bcmmbaga
1a5d59be1d Refactor 2024-03-14 14:35:21 +03:00
bcmmbaga
9db450d599 Add single Unix/Windows path check in process tests 2024-03-14 14:32:55 +03:00
bcmmbaga
cc60df7805 Allow set of single unix or windows path check 2024-03-14 14:32:40 +03:00
bcmmbaga
60f9f08ecb fix tests 2024-03-13 11:02:47 +03:00
bcmmbaga
41348bb39b Add process validation for peer metadata 2024-03-12 19:24:08 +03:00
bcmmbaga
e66e39cc70 Extend peer metadata with processes 2024-03-12 19:23:57 +03:00
bcmmbaga
9f41a1f20f add process posture check to posture checks handlers 2024-03-12 15:20:00 +03:00
bcmmbaga
5f0eec0add wip: add process check posture 2024-03-12 15:19:22 +03:00
91 changed files with 3250 additions and 1178 deletions

View File

@@ -33,6 +33,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
with:

View File

@@ -64,6 +64,10 @@ var installCmd = &cobra.Command{
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), svcConfig)

View File

@@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"runtime"
"runtime/debug"
"strings"
"time"
@@ -93,7 +95,13 @@ func runClient(
relayProbe *Probe,
wgProbe *Probe,
) error {
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
@@ -229,7 +237,10 @@ func runClient(
return wrapErr(err)
}
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
checks := loginResp.GetChecks()
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig,
mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
err = engine.Start()
if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"reflect"
"runtime"
"slices"
"strings"
"sync"
"time"
@@ -27,6 +28,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
@@ -138,6 +140,9 @@ type Engine struct {
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
}
// Peer is an instance of the Connection Peer
@@ -155,6 +160,7 @@ func NewEngine(
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
return NewEngineWithProbes(
ctx,
@@ -168,6 +174,7 @@ func NewEngine(
nil,
nil,
nil,
checks,
)
}
@@ -184,6 +191,7 @@ func NewEngineWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine {
return &Engine{
ctx: ctx,
@@ -204,6 +212,7 @@ func NewEngineWithProbes(
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
checks: checks,
}
}
@@ -486,6 +495,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
@@ -493,7 +506,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
}
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
@@ -583,7 +616,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
err := e.mgmClient.Sync(e.handleSync)
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
err = e.mgmClient.Sync(info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -794,6 +833,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
}
e.statusRecorder.ReplaceOfflinePeers(replacement)
@@ -1150,7 +1190,8 @@ func (e *Engine) close() {
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
netMap, err := e.mgmClient.GetNetworkMap()
info := system.GetInfo(e.ctx)
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, err
}
@@ -1328,3 +1369,10 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@@ -76,7 +76,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -210,7 +210,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -391,7 +391,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -406,7 +406,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -564,7 +564,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -733,7 +733,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -1002,7 +1002,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
}
func startSignal() (*grpc.Server, string, error) {

View File

@@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
}
conn.agent, err = ice.NewAgent(agentConfig)
if err != nil {
return err
}
@@ -285,6 +284,7 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -344,6 +344,7 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -465,9 +466,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
@@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {

View File

@@ -14,6 +14,7 @@ import (
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
IP string
PubKey string
FQDN string
@@ -30,7 +31,38 @@ type State struct {
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
Routes map[string]struct{}
routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
}
// LocalPeerState contains the latest state of the local peer
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
return nil
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.Routes != nil {
peerState.Routes = receivedState.Routes
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState, peerState)
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true
}
return false
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"errors"
"testing"
"sync"
"github.com/stretchr/testify/assert"
)
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
@@ -18,6 +19,7 @@ type routerPeerStatus struct {
connected bool
relayed bool
direct bool
latency time.Duration
}
type routesUpdate struct {
@@ -68,6 +70,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed,
direct: peerStatus.Direct,
latency: peerStatus.Latency,
}
}
return routePeerStatuses
@@ -83,11 +86,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
// * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
chosen := ""
chosenScore := 0
chosenScore := float64(0)
currScore := float64(0)
currID := ""
if c.chosenRoute != nil {
@@ -95,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
}
for _, r := range c.routes {
tempScore := 0
tempScore := float64(0)
peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected {
continue
@@ -103,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric
tempScore = metricDiff * 10
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
}
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
tempScore++
}
@@ -114,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
tempScore++
}
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
chosen = r.ID
chosenScore = tempScore
}
@@ -123,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
chosen = r.ID
chosenScore = tempScore
}
if r.ID == currID {
currScore = tempScore
}
}
if chosen == "" {
switch {
case chosen == "":
var peers []string
for _, r := range c.routes {
peers = append(peers, r.Peer)
}
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
} else if chosen != currID {
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
case chosen != currID:
if currScore != 0 && currScore < chosenScore+0.1 {
return currID
} else {
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
}
}
return chosen
@@ -174,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
return fmt.Errorf("get peer state: %v", err)
}
delete(state.Routes, c.network.String())
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
@@ -246,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}

View File

@@ -3,6 +3,7 @@ package routemanager
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/route"
)
@@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
name string
statuses map[string]routerPeerStatus
expectedRouteID string
currentRoute *route.Route
currentRoute string
existingRoutes map[string]*route.Route
}{
{
@@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "",
},
{
@@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "multiple connected peers with different latencies",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 300 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "should ignore routes with latency 0",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 12 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current chosen route doesn't exist anymore",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "routeDoesntExistAnymore",
expectedRouteID: "route2",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
ID: "routeDoesntExistAnymore",
}
if tc.currentRoute != "" {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
// create new clientNetwork
client := &clientNetwork{
network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes,
chosenRoute: tc.currentRoute,
chosenRoute: currentRoute,
}
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
}
m.ctx = nil
}
@@ -210,9 +219,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
}
func isPrefixSupported(prefix netip.Prefix) bool {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
}
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported

View File

@@ -8,6 +8,8 @@ import (
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
@@ -85,23 +87,42 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, ok := netip.AddrFromSlice(preferredSrc)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
}
return addr.Unmap(), intf, nil
}
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway)
addr, err := ipToAddr(gateway, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
}
return addr.Unmap(), intf, nil
return addr, intf, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
)
@@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
continue
}
if len(m.Addrs) < 3 {
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
continue
}
addr, ok := toNetIPAddr(m.Addrs[0])
if !ok {
continue
}
mask, ok := toNetIPMASK(m.Addrs[2])
if !ok {
continue
cidr := 32
if mask := m.Addrs[2]; mask != nil {
cidr, ok = toCIDR(mask)
if !ok {
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
continue
}
}
cidr, _ := mask.Size()
routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() {
@@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
switch t := a.(type) {
case *route.Inet4Addr:
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
addr := netip.MustParseAddr(ip.String())
return addr, true
return netip.AddrFrom4(t.IP), true
default:
return netip.Addr{}, false
}
}
func toNetIPMASK(a route.Addr) (net.IPMask, bool) {
func toCIDR(a route.Addr) (int, bool) {
switch t := a.(type) {
case *route.Inet4Addr:
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
return mask, true
cidr, _ := mask.Size()
return cidr, true
default:
return nil, false
return 0, false
}
}

View File

@@ -8,7 +8,9 @@ import (
"net/netip"
"os/exec"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -35,6 +37,10 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string)
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
if prefix.Addr().Is6() {
inet = "-inet6"
// Special case for IPv6 split default route, pointing to the wg interface fails
@@ -44,18 +50,40 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin
}
}
args := []string{"-n", action, inet, prefix.String()}
args := []string{"-n", action, inet, network}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != "" {
args = append(args, "-interface", intf)
}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return fmt.Errorf("route cmd retry failed: %w", err)
}
return nil
}

View File

@@ -5,8 +5,10 @@ package routemanager
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -29,6 +31,42 @@ func init() {
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := "lo0"
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()

View File

@@ -4,14 +4,14 @@ package routemanager
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -32,19 +32,31 @@ const (
rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct {
priority int
fwmark int
tableID int
family int
priority int
invert bool
suppressPrefix int
description string
@@ -52,10 +64,10 @@ type ruleParams struct {
func getSetupRules() []ruleParams {
return []ruleParams{
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}
}
@@ -69,8 +81,6 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
log.Infof("Using legacy routing setup")
@@ -81,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
@@ -123,11 +140,17 @@ func cleanupRouting() error {
rules := getSetupRules()
for _, rule := range rules {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
if err := removeRule(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil()
}
@@ -144,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
@@ -336,22 +363,8 @@ func flushRoutes(tableID, family int) error {
}
func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
}
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
//nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
_, err := setSysctl(ipv4ForwardingPath, 1, false)
return err
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
@@ -429,7 +442,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -446,47 +459,20 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil {
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
func removeAllRules(params ruleParams) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
for {
if ctx.Err() != nil {
done <- ctx.Err()
return
}
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
done <- nil
return
}
done <- err
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
route.Gw = addr.AsSlice()
if intf == "" {
intf = addr.Zone()
}
}
if intf != "" {
@@ -506,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int {
}
return netlink.FAMILY_V6
}
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, nil)
_, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())

View File

@@ -63,7 +63,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
return prefixList, nil
}
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
destinationPrefix := prefix.String()
psCmd := "New-NetRoute"
@@ -73,10 +73,20 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`,
psCmd, addressFamily, destinationPrefix, intf,
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
psCmd, addressFamily, destinationPrefix,
)
if intfIdx != "" {
script = fmt.Sprintf(
`%s -InterfaceIndex %s`, script, intfIdx,
)
} else {
script = fmt.Sprintf(
`%s -InterfaceAlias "%s"`, script, intf,
)
}
if nexthop.IsValid() {
script = fmt.Sprintf(
`%s -NextHop "%s"`, script, nexthop,
@@ -84,7 +94,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er
}
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
log.Tracef("PowerShell add route: %s", string(out))
log.Tracef("PowerShell %s: %s", script, string(out))
if err != nil {
return fmt.Errorf("PowerShell add route: %w", err)
@@ -98,7 +108,7 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s output: %s", strings.Join(args, " "), out)
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
}
@@ -107,9 +117,15 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
var intfIdx string
if nexthop.Zone() != "" {
intfIdx = nexthop.Zone()
nexthop.WithZone("")
}
// Powershell doesn't support adding routes without an interface but allows to add interface by name
if intf != "" {
return addRoutePowershell(prefix, nexthop, intf)
if intf != "" || intfIdx != "" {
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
}
return addRouteCmd(prefix, nexthop, intf)
}
@@ -117,11 +133,12 @@ func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String())
}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s output: %s", strings.Join(args, " "), out)
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)

View File

@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
}
// Set the fwmark on the socket.
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}

View File

@@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.Routes),
Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)

View File

@@ -25,8 +25,6 @@ func Detect(ctx context.Context) string {
detectDigitalOcean,
detectGCP,
detectOracle,
detectIBMCloud,
detectSoftlayer,
detectVultr,
}

View File

@@ -6,7 +6,7 @@ import (
)
func detectGCP(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil)
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil)
if err != nil {
return ""
}

View File

@@ -1,54 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectIBMCloud(ctx context.Context) string {
v1ResultChan := make(chan bool, 1)
v2ResultChan := make(chan bool, 1)
go func() {
v1ResultChan <- detectIBMSecure(ctx)
}()
go func() {
v2ResultChan <- detectIBM(ctx)
}()
v1Result, v2Result := <-v1ResultChan, <-v2ResultChan
if v1Result || v2Result {
return "IBM Cloud"
}
return ""
}
func detectIBMSecure(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
if err != nil {
return false
}
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
func detectIBM(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
if err != nil {
return false
}
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}

View File

@@ -1,25 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectSoftlayer(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil)
if err != nil {
return ""
}
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
// Since SoftLayer was acquired by IBM, we should return "IBM Cloud"
return "IBM Cloud"
}
return ""
}

View File

@@ -8,6 +8,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/version"
)
@@ -30,6 +31,12 @@ type Environment struct {
Platform string
}
type File struct {
Path string
Exist bool
ProcessIsRunning bool
}
// Info is an object that contains machine information
// Most of the code is taken from https://github.com/matishsiao/goInfo
type Info struct {
@@ -48,6 +55,7 @@ type Info struct {
SystemProductName string
SystemManufacturer string
Environment Environment
Files []File
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
@@ -129,3 +137,21 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
processCheckPaths := make([]string, 0)
for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(processCheckPaths)
if err != nil {
return nil, err
}
info := GetInfo(ctx)
info.Files = files
return info, nil
}

View File

@@ -36,6 +36,11 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
func uname() []string {
res := run("/system/bin/uname", "-a")
return strings.Split(res, " ")

View File

@@ -25,6 +25,11 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
// extractOsVersion extracts operating system version from context or returns the default
func extractOsVersion(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(OsVersionCtxKey).(string)

58
client/system/process.go Normal file
View File

@@ -0,0 +1,58 @@
//go:build windows || (linux && !android) || (darwin && !ios)
package system
import (
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processes, err := process.Processes()
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, p := range processes {
path, _ := p.Exe()
if path != "" {
processMap[path] = true
}
}
uniqueProcesses := make([]string, 0, len(processMap))
for p := range processMap {
uniqueProcesses = append(uniqueProcesses, p)
}
return uniqueProcesses, nil
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses()
if err != nil {
return nil, err
}
for i, path := range paths {
file := File{Path: path}
_, err := os.Stat(path)
file.Exist = !os.IsNotExist(err)
file.ProcessIsRunning = slices.Contains(runningProcesses, path)
files[i] = file
}
return files, nil
}

16
go.mod
View File

@@ -22,7 +22,7 @@ require (
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
golang.org/x/crypto v0.18.0
golang.org/x/sys v0.16.0
golang.org/x/sys v0.18.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -44,7 +44,7 @@ require (
github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.9
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -70,10 +70,11 @@ require (
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.14.0
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.3
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/things-go/go-socks5 v0.0.4
github.com/yusufpapurcu/wmi v1.2.3
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.0.2
go.opentelemetry.io/otel v1.11.1
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
@@ -131,6 +132,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
@@ -142,12 +144,16 @@ require (
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
github.com/yuin/goldmark v1.4.13 // indirect
go.opencensus.io v0.24.0 // indirect

36
go.sum
View File

@@ -247,9 +247,11 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
@@ -348,6 +350,8 @@ github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdA
github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU=
github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ=
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
@@ -383,8 +387,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
@@ -449,6 +453,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
@@ -485,6 +491,12 @@ github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shirou/gopsutil/v3 v3.24.3 h1:eoUGJSmdfLzJ3mxIhmOAhgKEKgQkeOwKpz1NbhVnuPE=
github.com/shirou/gopsutil/v3 v3.24.3/go.mod h1:JpND7O217xa72ewWz9zN2eIIkPWsDN/3pl0H8Qt0uwg=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
@@ -514,6 +526,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -524,10 +537,17 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0=
github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
@@ -544,8 +564,8 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
github.com/yuin/goldmark v1.3.8/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc=
github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
@@ -740,6 +760,7 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -761,8 +782,9 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View File

@@ -10,8 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
)
type wgKernelConfigurer struct {
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
if err != nil {
return err
}
fwmark := nbnet.NetbirdFwmark
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,

View File

@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
}
func getFwmark() int {
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark
}
return 0

View File

@@ -58,6 +58,7 @@ services:
command: [
"--port", "443",
"--log-file", "console",
"--log-level", "info",
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"

View File

@@ -2,7 +2,7 @@ version: "3"
services:
#UI dashboard
dashboard:
image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
restart: unless-stopped
#ports:
# - 80:80

View File

@@ -3,19 +3,21 @@ package client
import (
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Client interface {
io.Closer
Sync(msgHandler func(msg *proto.SyncResponse) error) error
Sync(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap() (*proto.NetworkMap, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
}

View File

@@ -255,7 +255,7 @@ func TestClient_Sync(t *testing.T) {
ch := make(chan *mgmtProto.SyncResponse, 1)
go func() {
err = client.Sync(func(msg *mgmtProto.SyncResponse) error {
err = client.Sync(info, func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})

View File

@@ -113,7 +113,7 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
func (c *GrpcClient) Sync(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
backOff := defaultBackoff(c.ctx)
operation := func() error {
@@ -135,7 +135,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -177,7 +177,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
@@ -186,7 +186,7 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
@@ -219,8 +219,8 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
@@ -430,6 +430,35 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
return flowInfoResp, nil
}
// SyncMeta sends updated system metadata to the Management Service.
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
if !c.ready() {
return fmt.Errorf("no connection to management")
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return err
}
syncMetaReq, err := encryption.EncryptMessage(*serverPubKey, c.key, &proto.SyncMetaRequest{Meta: infoToMetaData(sysInfo)})
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: syncMetaReq,
})
return err
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
@@ -463,6 +492,15 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
})
}
files := make([]*proto.File, 0, len(info.Files))
for _, file := range info.Files {
files = append(files, &proto.File{
Path: file.Path,
Exist: file.Exist,
ProcessIsRunning: file.ProcessIsRunning,
})
}
return &proto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
@@ -482,5 +520,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
Cloud: info.Environment.Cloud,
Platform: info.Environment.Platform,
},
Files: files,
}
}

View File

@@ -9,12 +9,13 @@ import (
type MockClient struct {
CloseFunc func() error
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
SyncFunc func(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
}
func (m *MockClient) IsHealthy() bool {
@@ -28,11 +29,11 @@ func (m *MockClient) Close() error {
return m.CloseFunc()
}
func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
func (m *MockClient) Sync(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
if m.SyncFunc == nil {
return nil
}
return m.SyncFunc(msgHandler)
return m.SyncFunc(sysInfo, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
@@ -71,6 +72,13 @@ func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
}
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
func (m *MockClient) GetNetworkMap() (*proto.NetworkMap, error) {
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
return nil, nil
}
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
if m.SyncMetaFunc == nil {
return nil
}
return m.SyncMetaFunc(sysInfo)
}

View File

@@ -251,7 +251,7 @@ var (
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -38,6 +38,12 @@ service ManagementService {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
}
message EncryptedMessage {
@@ -50,7 +56,10 @@ message EncryptedMessage {
int32 version = 3;
}
message SyncRequest {}
message SyncRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
// SyncResponse represents a state that should be applied to the local peer (e.g. Wiretrustee servers config as well as local peer and remote peers configs)
message SyncResponse {
@@ -69,6 +78,14 @@ message SyncResponse {
bool remotePeersIsEmpty = 4;
NetworkMap NetworkMap = 5;
// Posture checks to be evaluated by client
repeated Checks Checks = 6;
}
message SyncMetaRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
message LoginRequest {
@@ -82,6 +99,7 @@ message LoginRequest {
PeerKeys peerKeys = 4;
}
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
// This message is sent on Login or register requests, or when a key rotation has to happen.
message PeerKeys {
@@ -100,6 +118,16 @@ message Environment {
string platform = 2;
}
// File represents a file on the system.
message File {
// path is the path to the file.
string path = 1;
// exist indicate whether the file exists.
bool exist = 2;
// processIsRunning indicates whether the file is a running process or not.
bool processIsRunning = 3;
}
// PeerSystemMeta is machine meta data like OS and version.
message PeerSystemMeta {
string hostname = 1;
@@ -117,6 +145,7 @@ message PeerSystemMeta {
string sysProductName = 13;
string sysManufacturer = 14;
Environment environment = 15;
repeated File files = 16;
}
message LoginResponse {
@@ -124,6 +153,8 @@ message LoginResponse {
WiretrusteeConfig wiretrusteeConfig = 1;
// Peer local config
PeerConfig peerConfig = 2;
// Posture checks to be evaluated by client
repeated Checks Checks = 3;
}
message ServerKeyResponse {
@@ -371,3 +402,7 @@ message NetworkAddress {
string netIP = 1;
string mac = 2;
}
message Checks {
repeated string Files= 1;
}

View File

@@ -43,6 +43,11 @@ type ManagementServiceClient interface {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
}
type managementServiceClient struct {
@@ -130,6 +135,15 @@ func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context,
return out, nil
}
func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility
@@ -159,6 +173,11 @@ type ManagementServiceServer interface {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
mustEmbedUnimplementedManagementServiceServer()
}
@@ -184,6 +203,9 @@ func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.C
func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -308,6 +330,24 @@ func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx co
return interceptor(ctx, in, info, handler)
}
func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).SyncMeta(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/SyncMeta",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -335,6 +375,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetPKCEAuthorizationFlow",
Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler,
},
{
MethodName: "SyncMeta",
Handler: _ManagementService_SyncMeta_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -114,6 +114,7 @@ type AccountManager interface {
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
@@ -242,19 +243,19 @@ type UserPermissions struct {
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"`
Permissions UserPermissions `json:"permissions"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -278,7 +279,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
return routes
}
// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -1120,7 +1121,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
}
if user.Id != account.CreatedBy {
if user.Role != UserRoleOwner {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
for _, otherUser := range account.Users {
@@ -1473,7 +1474,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
@@ -1849,6 +1850,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
}
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
log.Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed to get account %s: %v", accountID, err)
@@ -1861,9 +1863,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &nbgroup.Group{
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
AccountID: account.Id,
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)
@@ -1907,7 +1910,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)
users[userID] = NewOwnerUser(userID, accountID)
dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}

View File

@@ -11,133 +11,134 @@ type Code struct {
Code string
}
// Existing consts must not be changed, as this will break the compatibility with the existing data
const (
// PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = iota
PeerAddedByUser Activity = 0
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
PeerAddedWithSetupKey
PeerAddedWithSetupKey Activity = 1
// UserJoined indicates that a new user joined the account
UserJoined
UserJoined Activity = 2
// UserInvited indicates that a new user was invited to join the account
UserInvited
UserInvited Activity = 3
// AccountCreated indicates that a new account has been created
AccountCreated
AccountCreated Activity = 4
// PeerRemovedByUser indicates that a user removed a peer from the system
PeerRemovedByUser
PeerRemovedByUser Activity = 5
// RuleAdded indicates that a user added a new rule
RuleAdded
RuleAdded Activity = 6
// RuleUpdated indicates that a user updated a rule
RuleUpdated
RuleUpdated Activity = 7
// RuleRemoved indicates that a user removed a rule
RuleRemoved
RuleRemoved Activity = 8
// PolicyAdded indicates that a user added a new policy
PolicyAdded
PolicyAdded Activity = 9
// PolicyUpdated indicates that a user updated a policy
PolicyUpdated
PolicyUpdated Activity = 10
// PolicyRemoved indicates that a user removed a policy
PolicyRemoved
PolicyRemoved Activity = 11
// SetupKeyCreated indicates that a user created a new setup key
SetupKeyCreated
SetupKeyCreated Activity = 12
// SetupKeyUpdated indicates that a user updated a setup key
SetupKeyUpdated
SetupKeyUpdated Activity = 13
// SetupKeyRevoked indicates that a user revoked a setup key
SetupKeyRevoked
SetupKeyRevoked Activity = 14
// SetupKeyOverused indicates that setup key usage exhausted
SetupKeyOverused
SetupKeyOverused Activity = 15
// GroupCreated indicates that a user created a group
GroupCreated
GroupCreated Activity = 16
// GroupUpdated indicates that a user updated a group
GroupUpdated
GroupUpdated Activity = 17
// GroupAddedToPeer indicates that a user added group to a peer
GroupAddedToPeer
GroupAddedToPeer Activity = 18
// GroupRemovedFromPeer indicates that a user removed peer group
GroupRemovedFromPeer
GroupRemovedFromPeer Activity = 19
// GroupAddedToUser indicates that a user added group to a user
GroupAddedToUser
GroupAddedToUser Activity = 20
// GroupRemovedFromUser indicates that a user removed a group from a user
GroupRemovedFromUser
GroupRemovedFromUser Activity = 21
// UserRoleUpdated indicates that a user changed the role of a user
UserRoleUpdated
UserRoleUpdated Activity = 22
// GroupAddedToSetupKey indicates that a user added group to a setup key
GroupAddedToSetupKey
GroupAddedToSetupKey Activity = 23
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
GroupRemovedFromSetupKey
GroupRemovedFromSetupKey Activity = 24
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
GroupAddedToDisabledManagementGroups
GroupAddedToDisabledManagementGroups Activity = 25
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
GroupRemovedFromDisabledManagementGroups
GroupRemovedFromDisabledManagementGroups Activity = 26
// RouteCreated indicates that a user created a route
RouteCreated
RouteCreated Activity = 27
// RouteRemoved indicates that a user deleted a route
RouteRemoved
RouteRemoved Activity = 28
// RouteUpdated indicates that a user updated a route
RouteUpdated
RouteUpdated Activity = 29
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
PeerSSHEnabled
PeerSSHEnabled Activity = 30
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
PeerSSHDisabled
PeerSSHDisabled Activity = 31
// PeerRenamed indicates that a user renamed a peer
PeerRenamed
PeerRenamed Activity = 32
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
PeerLoginExpirationEnabled
PeerLoginExpirationEnabled Activity = 33
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
PeerLoginExpirationDisabled
PeerLoginExpirationDisabled Activity = 34
// NameserverGroupCreated indicates that a user created a nameservers group
NameserverGroupCreated
NameserverGroupCreated Activity = 35
// NameserverGroupDeleted indicates that a user deleted a nameservers group
NameserverGroupDeleted
NameserverGroupDeleted Activity = 36
// NameserverGroupUpdated indicates that a user updated a nameservers group
NameserverGroupUpdated
NameserverGroupUpdated Activity = 37
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
AccountPeerLoginExpirationEnabled
AccountPeerLoginExpirationEnabled Activity = 38
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
AccountPeerLoginExpirationDisabled
AccountPeerLoginExpirationDisabled Activity = 39
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
AccountPeerLoginExpirationDurationUpdated
AccountPeerLoginExpirationDurationUpdated Activity = 40
// PersonalAccessTokenCreated indicates that a user created a personal access token
PersonalAccessTokenCreated
PersonalAccessTokenCreated Activity = 41
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted
PersonalAccessTokenDeleted Activity = 42
// ServiceUserCreated indicates that a user created a service user
ServiceUserCreated
ServiceUserCreated Activity = 43
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
ServiceUserDeleted Activity = 44
// UserBlocked indicates that a user blocked another user
UserBlocked
UserBlocked Activity = 45
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
UserUnblocked Activity = 46
// UserDeleted indicates that a user deleted another user
UserDeleted
UserDeleted Activity = 47
// GroupDeleted indicates that a user deleted group
GroupDeleted
GroupDeleted Activity = 48
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer
UserLoggedInPeer Activity = 49
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired
PeerLoginExpired Activity = 50
// DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin
DashboardLogin Activity = 51
// IntegrationCreated indicates that the user created an integration
IntegrationCreated
IntegrationCreated Activity = 52
// IntegrationUpdated indicates that the user updated an integration
IntegrationUpdated
IntegrationUpdated Activity = 53
// IntegrationDeleted indicates that the user deleted an integration
IntegrationDeleted
IntegrationDeleted Activity = 54
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
AccountPeerApprovalEnabled
AccountPeerApprovalEnabled Activity = 55
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
AccountPeerApprovalDisabled
AccountPeerApprovalDisabled Activity = 56
// PeerApproved indicates that the peer has been approved
PeerApproved
PeerApproved Activity = 57
// PeerApprovalRevoked indicates that the peer approval has been revoked
PeerApprovalRevoked
PeerApprovalRevoked Activity = 58
// TransferredOwnerRole indicates that the user transferred the owner role of the account
TransferredOwnerRole
TransferredOwnerRole Activity = 59
// PostureCheckCreated indicates that the user created a posture check
PostureCheckCreated
PostureCheckCreated Activity = 60
// PostureCheckUpdated indicates that the user updated a posture check
PostureCheckUpdated
PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted
PostureCheckDeleted Activity = 62
)
var activityMap = map[Activity]Code{

View File

@@ -134,7 +134,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err
}
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
if syncReq.GetMeta() == nil {
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{
WireGuardPubKey: peerKey.String(),
Meta: extractPeerMeta(syncReq.GetMeta()),
})
if err != nil {
return mapError(err)
}
@@ -255,14 +262,18 @@ func mapError(err error) error {
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
osVersion := loginReq.GetMeta().GetOSVersion()
if osVersion == "" {
osVersion = loginReq.GetMeta().GetCore()
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
if meta == nil {
return nbpeer.PeerSystemMeta{}
}
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(loginReq.GetMeta().GetNetworkAddresses()))
for _, addr := range loginReq.GetMeta().GetNetworkAddresses() {
osVersion := meta.GetOSVersion()
if osVersion == "" {
osVersion = meta.GetCore()
}
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(meta.GetNetworkAddresses()))
for _, addr := range meta.GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil {
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
@@ -274,24 +285,34 @@ func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
})
}
files := make([]nbpeer.File, 0, len(meta.GetFiles()))
for _, file := range meta.GetFiles() {
files = append(files, nbpeer.File{
Path: file.GetPath(),
Exist: file.GetExist(),
ProcessIsRunning: file.GetProcessIsRunning(),
})
}
return nbpeer.PeerSystemMeta{
Hostname: loginReq.GetMeta().GetHostname(),
GoOS: loginReq.GetMeta().GetGoOS(),
Kernel: loginReq.GetMeta().GetKernel(),
Platform: loginReq.GetMeta().GetPlatform(),
OS: loginReq.GetMeta().GetOS(),
Hostname: meta.GetHostname(),
GoOS: meta.GetGoOS(),
Kernel: meta.GetKernel(),
Platform: meta.GetPlatform(),
OS: meta.GetOS(),
OSVersion: osVersion,
WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(),
UIVersion: loginReq.GetMeta().GetUiVersion(),
KernelVersion: loginReq.GetMeta().GetKernelVersion(),
WtVersion: meta.GetWiretrusteeVersion(),
UIVersion: meta.GetUiVersion(),
KernelVersion: meta.GetKernelVersion(),
NetworkAddresses: networkAddresses,
SystemSerialNumber: loginReq.GetMeta().GetSysSerialNumber(),
SystemProductName: loginReq.GetMeta().GetSysProductName(),
SystemManufacturer: loginReq.GetMeta().GetSysManufacturer(),
SystemSerialNumber: meta.GetSysSerialNumber(),
SystemProductName: meta.GetSysProductName(),
SystemManufacturer: meta.GetSysManufacturer(),
Environment: nbpeer.Environment{
Cloud: loginReq.GetMeta().GetEnvironment().GetCloud(),
Platform: loginReq.GetMeta().GetEnvironment().GetPlatform(),
Cloud: meta.GetEnvironment().GetCloud(),
Platform: meta.GetEnvironment().GetPlatform(),
},
Files: files,
}
}
@@ -366,7 +387,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq),
Meta: extractPeerMeta(loginReq.GetMeta()),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
@@ -386,6 +407,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toPeerChecks(s.accountManager, peerKey.String()),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
@@ -482,7 +504,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -513,6 +535,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCred
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toPeerChecks(accountManager, peer.Key),
}
}
@@ -531,7 +554,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
@@ -648,3 +671,62 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
Body: encryptedResp,
}, nil
}
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
// peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncMetaReq := &proto.SyncMetaRequest{}
peerKey, err := s.parseRequest(req, syncMetaReq)
if err != nil {
return nil, err
}
if syncMetaReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
return nil, msg
}
_, _, err = s.accountManager.SyncPeer(PeerSync{
WireGuardPubKey: peerKey.String(),
Meta: extractPeerMeta(syncMetaReq.GetMeta()),
UpdateAccountPeers: true,
})
if err != nil {
return nil, mapError(err)
}
return &proto.Empty{}, nil
}
// toPeerChecks returns posture checks for the peer that needs to be evaluated on the client side.
func toPeerChecks(accountManager AccountManager, peerKey string) []*proto.Checks {
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey)
if err != nil {
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
return nil
}
protoChecks := make([]*proto.Checks, 0)
for _, postureCheck := range postureChecks {
protoCheck := &proto.Checks{}
if check := postureCheck.Checks.ProcessCheck; check != nil {
for _, process := range check.Processes {
if process.Path != "" {
protoCheck.Files = append(protoCheck.Files, process.Path)
}
if process.WindowsPath != "" {
protoCheck.Files = append(protoCheck.Files, process.WindowsPath)
}
}
}
protoChecks = append(protoChecks, protoCheck)
}
return protoChecks
}

View File

@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v }

View File

@@ -812,6 +812,8 @@ components:
$ref: '#/components/schemas/GeoLocationCheck'
peer_network_range_check:
$ref: '#/components/schemas/PeerNetworkRangeCheck'
process_check:
$ref: '#/components/schemas/ProcessCheck'
NBVersionCheck:
description: Posture check for the version of NetBird
type: object
@@ -900,6 +902,28 @@ components:
required:
- ranges
- action
ProcessCheck:
description: Posture Check for binaries exist and are running in the peers system
type: object
properties:
processes:
type: array
items:
$ref: '#/components/schemas/Process'
required:
- processes
Process:
description: Describes the operational activity within a peer's system.
type: object
properties:
path:
description: Path to the process executable file in a Unix-like operating system
type: string
example: "/usr/local/bin/netbird"
windows_path:
description: Path to the process executable file in a Windows operating system
type: string
example: "C:\ProgramData\NetBird\netbird.exe"
Location:
description: Describe geographical location information
type: object

View File

@@ -225,6 +225,9 @@ type Checks struct {
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
// ProcessCheck Posture Check for binaries exist and are running in the peers system
ProcessCheck *ProcessCheck `json:"process_check,omitempty"`
}
// City Describe city geographical location information
@@ -940,6 +943,20 @@ type PostureCheckUpdate struct {
Name string `json:"name"`
}
// Process Describes the operational activity within a peer's system.
type Process struct {
// Path Path to the process executable file in a Unix-like operating system
Path *string `json:"path,omitempty"`
// WindowsPath Path to the process executable file in a Windows operating system
WindowsPath *string `json:"windows_path,omitempty"`
}
// ProcessCheck Posture Check for binaries exist and are running in the peers system
type ProcessCheck struct {
Processes []Process `json:"processes"`
}
// Route defines model for Route.
type Route struct {
// Description Route description

View File

@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
Id: testDNSSettingsAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
testDNSSettingsUserID: server.NewAdminUser("test_user"),
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
},
DNSSettings: baseExistingDNSSettings,
}

View File

@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
},
}
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...)

View File

@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -2,6 +2,7 @@ package http
import (
"net/http"
"regexp"
"github.com/gorilla/mux"
@@ -13,6 +14,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// GeolocationsHandler is a handler that returns locations.
type GeolocationsHandler struct {
accountManager server.AccountManager
@@ -73,8 +78,8 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}

View File

@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {

View File

@@ -12,6 +12,7 @@ import (
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -38,7 +39,7 @@ type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

View File

@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
return nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -4,8 +4,6 @@ import (
"encoding/json"
"net/http"
"net/netip"
"regexp"
"slices"
"github.com/gorilla/mux"
"github.com/rs/xid"
@@ -19,10 +17,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// PostureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct {
accountManager server.AccountManager
@@ -165,19 +159,16 @@ func (p *PostureChecksHandler) savePostureChecks(
user *server.User,
postureChecksID string,
) {
var (
err error
req api.PostureCheckUpdate
)
var req api.PostureCheckUpdate
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := validatePostureChecksUpdate(req)
if err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
@@ -206,8 +197,8 @@ func (p *PostureChecksHandler) savePostureChecks(
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
@@ -221,6 +212,10 @@ func (p *PostureChecksHandler) savePostureChecks(
}
}
if processCheck := req.Checks.ProcessCheck; processCheck != nil {
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
}
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil {
util.WriteError(err, w)
return
@@ -229,72 +224,6 @@ func (p *PostureChecksHandler) savePostureChecks(
util.WriteJSONObject(w, toPostureChecksResponse(&postureChecks))
}
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
if req.Name == "" {
return status.Errorf(status.InvalidArgument, "posture checks name shouldn't be empty")
}
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil) {
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
}
if req.Checks.NbVersionCheck != nil && req.Checks.NbVersionCheck.MinVersion == "" {
return status.Errorf(status.InvalidArgument, "minimum version for NetBird's version check shouldn't be empty")
}
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
emptyOS := osVersionCheck.Android == nil && osVersionCheck.Darwin == nil && osVersionCheck.Ios == nil &&
osVersionCheck.Linux == nil && osVersionCheck.Windows == nil
emptyMinVersion := osVersionCheck.Android != nil && osVersionCheck.Android.MinVersion == "" ||
osVersionCheck.Darwin != nil && osVersionCheck.Darwin.MinVersion == "" ||
osVersionCheck.Ios != nil && osVersionCheck.Ios.MinVersion == "" ||
osVersionCheck.Linux != nil && osVersionCheck.Linux.MinKernelVersion == "" ||
osVersionCheck.Windows != nil && osVersionCheck.Windows.MinKernelVersion == ""
if emptyOS || emptyMinVersion {
return status.Errorf(status.InvalidArgument,
"minimum version for at least one OS in the OS version check shouldn't be empty")
}
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if geoLocationCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for geolocation check shouldn't be empty")
}
allowedActions := []api.GeoLocationCheckAction{api.GeoLocationCheckActionAllow, api.GeoLocationCheckActionDeny}
if !slices.Contains(allowedActions, geoLocationCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for geolocation check is not valid value")
}
if len(geoLocationCheck.Locations) == 0 {
return status.Errorf(status.InvalidArgument, "locations for geolocation check shouldn't be empty")
}
for _, loc := range geoLocationCheck.Locations {
if loc.CountryCode == "" {
return status.Errorf(status.InvalidArgument, "country code for geolocation check shouldn't be empty")
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return status.Errorf(status.InvalidArgument, "country code must be 2 letters (ISO 3166-1 alpha-2 format)")
}
}
}
if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
if peerNetworkRangeCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
}
allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny}
if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value")
}
if len(peerNetworkRangeCheck.Ranges) == 0 {
return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty")
}
}
return nil
}
func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
var checks api.Checks
@@ -322,6 +251,10 @@ func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(postureChecks.Checks.PeerNetworkRangeCheck)
}
if postureChecks.Checks.ProcessCheck != nil {
checks.ProcessCheck = toProcessCheckResponse(postureChecks.Checks.ProcessCheck)
}
return &api.PostureCheck{
Id: postureChecks.ID,
Name: postureChecks.Name,
@@ -332,11 +265,10 @@ func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
func toGeoLocationCheckResponse(geoLocationCheck *posture.GeoLocationCheck) *api.GeoLocationCheck {
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
for _, loc := range geoLocationCheck.Locations {
l := loc // make G601 happy
for i, loc := range geoLocationCheck.Locations {
var cityName *string
if loc.CityName != "" {
cityName = &l.CityName
cityName = &geoLocationCheck.Locations[i].CityName
}
locations = append(locations, api.Location{
CityName: cityName,
@@ -396,3 +328,36 @@ func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*posture.PeerNet
Action: string(check.Action),
}, nil
}
func toProcessCheckResponse(check *posture.ProcessCheck) *api.ProcessCheck {
processes := make([]api.Process, 0, len(check.Processes))
for i := range check.Processes {
processes = append(processes, api.Process{
Path: &check.Processes[i].Path,
WindowsPath: &check.Processes[i].WindowsPath,
})
}
return &api.ProcessCheck{
Processes: processes,
}
}
func toProcessCheck(check *api.ProcessCheck) *posture.ProcessCheck {
processes := make([]posture.Process, 0, len(check.Processes))
for _, process := range check.Processes {
var p posture.Process
if process.Path != nil {
p.Path = *process.Path
}
if process.WindowsPath != nil {
p.WindowsPath = *process.WindowsPath
}
processes = append(processes, p)
}
return &posture.ProcessCheck{
Processes: processes,
}
}

View File

@@ -43,6 +43,11 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error())
}
return nil
},
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
@@ -62,7 +67,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
@@ -433,6 +438,43 @@ func TestPostureCheckUpdate(t *testing.T) {
handler.geolocationManager = nil
},
},
{
name: "Create Posture Checks Process Check",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"process_check": {
"processes": [
{
"path": "/usr/local/bin/netbird",
"windows_path": "C:\\ProgramData\\NetBird\\netbird.exe"
}
]
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
ProcessCheck: &api.ProcessCheck{
Processes: []api.Process{
{
Path: str("/usr/local/bin/netbird"),
WindowsPath: str("C:\\ProgramData\\NetBird\\netbird.exe"),
},
},
},
},
},
},
{
name: "Create Posture Checks Invalid Check",
requestType: http.MethodPost,
@@ -446,7 +488,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -461,7 +503,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -475,7 +517,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -489,7 +531,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"geo_location_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -663,11 +705,8 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Invalid Check",
@@ -682,7 +721,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -697,7 +736,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -711,7 +750,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -841,100 +880,3 @@ func TestPostureCheckUpdate(t *testing.T) {
})
}
}
func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
// empty name
err := validatePostureChecksUpdate(api.PostureCheckUpdate{})
assert.Error(t, err)
// empty checks
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default"})
assert.Error(t, err)
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{}})
assert.Error(t, err)
// not valid NbVersionCheck
nbVersionCheck := api.NBVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.Error(t, err)
// valid NbVersionCheck
nbVersionCheck = api.NBVersionCheck{MinVersion: "1.0"}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.NoError(t, err)
// not valid OsVersionCheck
osVersionCheck := api.OSVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}, Darwin: &api.MinVersionCheck{MinVersion: "14.2"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &api.MinVersionCheck{MinVersion: "14.2"},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid peer network range check
peerNetworkRangeCheck := api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionAllow,
Ranges: []string{
"192.168.1.0/24", "10.0.0.0/8",
},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.NoError(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionDeny,
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: "unknownAction",
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
}

View File

@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true)

View File

@@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error {
return nil
}
// parseOktaUserToUserData parse okta user to UserData.
// parseOktaUser parse okta user to UserData.
func parseOktaUser(user *okta.User) (*UserData, error) {
var oktaUser struct {
Email string `json:"email"`

View File

@@ -134,7 +134,8 @@ func Test_SyncProtocol(t *testing.T) {
// take the first registered peer as a base for the test. Total four.
key := *peers[0]
message, err := encryption.EncryptMessage(*serverKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message, err := encryption.EncryptMessage(*serverKey, key, syncReq)
if err != nil {
t.Fatal(err)
return

View File

@@ -93,7 +93,8 @@ var _ = Describe("Management service", func() {
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
Expect(err).NotTo(HaveOccurred())
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
@@ -143,7 +144,7 @@ var _ = Describe("Management service", func() {
loginPeerWithValidSetupKey(serverPubKey, key1, client)
loginPeerWithValidSetupKey(serverPubKey, key2, client)
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred())
@@ -176,7 +177,7 @@ var _ = Describe("Management service", func() {
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred())
@@ -329,7 +330,7 @@ var _ = Describe("Management service", func() {
var clients []mgmtProto.ManagementService_SyncClient
for _, peer := range peers {
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer)
Expect(err).NotTo(HaveOccurred())
@@ -394,7 +395,8 @@ var _ = Describe("Management service", func() {
defer GinkgoRecover()
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
Expect(err).NotTo(HaveOccurred())
// open stream

View File

@@ -80,6 +80,7 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecksFunc func(peerKey string) ([]posture.Checks, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
@@ -609,6 +610,14 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
}
// GetPeerAppliedPostureChecks mocks GetPeerAppliedPostureChecks of the AccountManager interface
func (am *MockAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
if am.GetPeerAppliedPostureChecksFunc != nil {
return am.GetPeerAppliedPostureChecksFunc(peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerAppliedPostureChecks is not implemented")
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
if am.UpdateAccountSettingsFunc != nil {
@@ -706,7 +715,7 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
return nil
}
// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
if am.UpdateIntegratedValidatorGroupsFunc != nil {
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)

View File

@@ -3,9 +3,10 @@ package mock_server
import (
"context"
"github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/proto"
)
type ManagementServiceServerMock struct {
@@ -17,6 +18,7 @@ type ManagementServiceServerMock struct {
IsHealthyFunc func(context.Context, *proto.Empty) (*proto.Empty, error)
GetDeviceAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
GetPKCEAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
SyncMetaFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error)
}
func (m ManagementServiceServerMock) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
@@ -60,3 +62,10 @@ func (m ManagementServiceServerMock) GetPKCEAuthorizationFlow(ctx context.Contex
}
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (m ManagementServiceServerMock) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
if m.SyncMetaFunc != nil {
return m.SyncMetaFunc(ctx, req)
}
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}

View File

@@ -3,6 +3,7 @@ package server
import (
"fmt"
"net"
"slices"
"strings"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
@@ -19,6 +21,11 @@ import (
type PeerSync struct {
// WireGuardPubKey is a peers WireGuard public key
WireGuardPubKey string
// Meta is the system information passed by peer, must be always present
Meta nbpeer.PeerSystemMeta
// UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated
UpdateAccountPeers bool
}
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
@@ -551,8 +558,20 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if requiresApproval {
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(account)
}
}
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
@@ -563,11 +582,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
am.updateAccountPeers(account)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
}
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
}
// LoginPeer logs in or registers a peer.
@@ -866,7 +885,65 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
}
for _, peer := range peers {
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
update := toSyncResponse(am, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
}
}
// GetPeerAppliedPostureChecks returns posture checks that are applied to the peer.
func (am *DefaultAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
account, err := am.Store.GetAccountByPeerPubKey(peerKey)
if err != nil {
log.Errorf("failed while getting peer %s: %v", peerKey, err)
return nil, err
}
peer, err := account.FindPeerByPubKey(peerKey)
if err != nil {
return nil, status.Errorf(status.NotFound, "peer is not registered")
}
if peer == nil {
return nil, nil
}
peerPostureChecks := make(map[string]posture.Checks)
for _, policy := range account.Policies {
if !policy.Enabled {
continue
}
outerLoop:
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup]
if !ok {
continue
}
// check if peer is in the rule source group
if slices.Contains(group.Peers, peer.ID) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureChecks := range account.PostureChecks {
if postureChecks.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureChecks
}
}
}
break outerLoop
}
}
}
}
postureChecksList := make([]posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
postureChecksList = append(postureChecksList, check)
}
return postureChecksList, nil
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
)
@@ -79,6 +80,13 @@ type Environment struct {
Platform string
}
// File is a file on the system.
type File struct {
Path string
Exist bool
ProcessIsRunning bool
}
// PeerSystemMeta is a metadata of a Peer machine system
type PeerSystemMeta struct { //nolint:revive
Hostname string
@@ -96,24 +104,22 @@ type PeerSystemMeta struct { //nolint:revive
SystemProductName string
SystemManufacturer string
Environment Environment `gorm:"serializer:json"`
Files []File `gorm:"serializer:json"`
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
if len(p.NetworkAddresses) != len(other.NetworkAddresses) {
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
for _, addr := range p.NetworkAddresses {
var found bool
for _, oAddr := range other.NetworkAddresses {
if addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP {
found = true
continue
}
}
if !found {
return false
}
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
@@ -133,6 +139,26 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
p.Environment.Platform == other.Environment.Platform
}
func (p PeerSystemMeta) isEmpty() bool {
return p.Hostname == "" &&
p.GoOS == "" &&
p.Kernel == "" &&
p.Core == "" &&
p.Platform == "" &&
p.OS == "" &&
p.OSVersion == "" &&
p.WtVersion == "" &&
p.UIVersion == "" &&
p.KernelVersion == "" &&
len(p.NetworkAddresses) == 0 &&
p.SystemSerialNumber == "" &&
p.SystemProductName == "" &&
p.SystemManufacturer == "" &&
p.Environment.Cloud == "" &&
p.Environment.Platform == "" &&
len(p.Files) == 0
}
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
func (p *Peer) AddedWithSSOLogin() bool {
return p.UserID != ""
@@ -168,6 +194,10 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
if meta.isEmpty() {
return false
}
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion

View File

@@ -1,8 +1,9 @@
package posture
import (
"fmt"
"errors"
"net/netip"
"regexp"
"github.com/hashicorp/go-version"
@@ -14,15 +15,21 @@ const (
OSVersionCheckName = "OSVersionCheck"
GeoLocationCheckName = "GeoLocationCheck"
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
ProcessCheckName = "ProcessCheck"
CheckActionAllow string = "allow"
CheckActionDeny string = "deny"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// Check represents an interface for performing a check on a peer.
type Check interface {
Check(peer nbpeer.Peer) (bool, error)
Name() string
Check(peer nbpeer.Peer) (bool, error)
Validate() error
}
type Checks struct {
@@ -48,6 +55,7 @@ type ChecksDefinition struct {
OSVersionCheck *OSVersionCheck `json:",omitempty"`
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
ProcessCheck *ProcessCheck `json:",omitempty"`
}
// Copy returns a copy of a checks definition.
@@ -93,6 +101,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
}
copy(cdCopy.PeerNetworkRangeCheck.Ranges, peerNetRangeCheck.Ranges)
}
if cd.ProcessCheck != nil {
processCheck := cd.ProcessCheck
cdCopy.ProcessCheck = &ProcessCheck{
Processes: make([]Process, len(processCheck.Processes)),
}
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
}
return cdCopy
}
@@ -133,50 +148,49 @@ func (pc *Checks) GetChecks() []Check {
if pc.Checks.PeerNetworkRangeCheck != nil {
checks = append(checks, pc.Checks.PeerNetworkRangeCheck)
}
if pc.Checks.ProcessCheck != nil {
checks = append(checks, pc.Checks.ProcessCheck)
}
return checks
}
// Validate checks the validity of a posture checks.
func (pc *Checks) Validate() error {
if check := pc.Checks.NBVersionCheck; check != nil {
if !isVersionValid(check.MinVersion) {
return fmt.Errorf("%s version: %s is not valid", check.Name(), check.MinVersion)
}
if pc.Name == "" {
return errors.New("posture checks name shouldn't be empty")
}
if osCheck := pc.Checks.OSVersionCheck; osCheck != nil {
if osCheck.Android != nil {
if !isVersionValid(osCheck.Android.MinVersion) {
return fmt.Errorf("%s android version: %s is not valid", osCheck.Name(), osCheck.Android.MinVersion)
}
}
if osCheck.Ios != nil {
if !isVersionValid(osCheck.Ios.MinVersion) {
return fmt.Errorf("%s ios version: %s is not valid", osCheck.Name(), osCheck.Ios.MinVersion)
}
}
if osCheck.Darwin != nil {
if !isVersionValid(osCheck.Darwin.MinVersion) {
return fmt.Errorf("%s darwin version: %s is not valid", osCheck.Name(), osCheck.Darwin.MinVersion)
}
}
if osCheck.Linux != nil {
if !isVersionValid(osCheck.Linux.MinKernelVersion) {
return fmt.Errorf("%s linux kernel version: %s is not valid", osCheck.Name(),
osCheck.Linux.MinKernelVersion)
}
}
if osCheck.Windows != nil {
if !isVersionValid(osCheck.Windows.MinKernelVersion) {
return fmt.Errorf("%s windows kernel version: %s is not valid", osCheck.Name(),
osCheck.Windows.MinKernelVersion)
}
}
// posture check should contain at least one check
if pc.Checks.NBVersionCheck == nil && pc.Checks.OSVersionCheck == nil &&
pc.Checks.GeoLocationCheck == nil && pc.Checks.PeerNetworkRangeCheck == nil && pc.Checks.ProcessCheck == nil {
return errors.New("posture checks shouldn't be empty")
}
if pc.Checks.NBVersionCheck != nil {
if err := pc.Checks.NBVersionCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.OSVersionCheck != nil {
if err := pc.Checks.OSVersionCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.GeoLocationCheck != nil {
if err := pc.Checks.GeoLocationCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.PeerNetworkRangeCheck != nil {
if err := pc.Checks.PeerNetworkRangeCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.ProcessCheck != nil {
if err := pc.Checks.ProcessCheck.Validate(); err != nil {
return err
}
}
return nil
}

View File

@@ -150,9 +150,23 @@ func TestChecks_Validate(t *testing.T) {
checks Checks
expectedError bool
}{
{
name: "Empty name",
checks: Checks{},
expectedError: true,
},
{
name: "Empty checks",
checks: Checks{
Name: "Default",
Checks: ChecksDefinition{},
},
expectedError: true,
},
{
name: "Valid checks version",
checks: Checks{
Name: "default",
Checks: ChecksDefinition{
NBVersionCheck: &NBVersionCheck{
MinVersion: "0.25.0",
@@ -261,6 +275,14 @@ func TestChecks_Copy(t *testing.T) {
},
Action: CheckActionDeny,
},
ProcessCheck: &ProcessCheck{
Processes: []Process{
{
Path: "/Applications/NetBird.app/Contents/MacOS/netbird",
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
},
}
checkCopy := check.Copy()

View File

@@ -2,6 +2,7 @@ package posture
import (
"fmt"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -60,3 +61,28 @@ func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) {
func (g *GeoLocationCheck) Name() string {
return GeoLocationCheckName
}
func (g *GeoLocationCheck) Validate() error {
if g.Action == "" {
return fmt.Errorf("%s action shouldn't be empty", g.Name())
}
allowedActions := []string{CheckActionAllow, CheckActionDeny}
if !slices.Contains(allowedActions, g.Action) {
return fmt.Errorf("%s action is not valid", g.Name())
}
if len(g.Locations) == 0 {
return fmt.Errorf("%s locations shouldn't be empty", g.Name())
}
for _, loc := range g.Locations {
if loc.CountryCode == "" {
return fmt.Errorf("%s country code shouldn't be empty", g.Name())
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return fmt.Errorf("%s country code must be 2 letters (ISO 3166-1 alpha-2 format)", g.Name())
}
}
return nil
}

View File

@@ -236,3 +236,81 @@ func TestGeoLocationCheck_Check(t *testing.T) {
})
}
}
func TestGeoLocationCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check GeoLocationCheck
expectedError bool
}{
{
name: "Valid location list",
check: GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
},
expectedError: false,
},
{
name: "Invalid empty location list",
check: GeoLocationCheck{
Action: CheckActionDeny,
Locations: []Location{},
},
expectedError: true,
},
{
name: "Invalid empty country name",
check: GeoLocationCheck{
Action: CheckActionDeny,
Locations: []Location{
{
CityName: "Los Angeles",
},
},
},
expectedError: true,
},
{
name: "Invalid check action",
check: GeoLocationCheck{
Action: "unknownAction",
Locations: []Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
},
expectedError: true,
},
{
name: "Invalid country code",
check: GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{
{
CountryCode: "USA",
},
},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,6 +1,8 @@
package posture
import (
"fmt"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
@@ -37,3 +39,13 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
func (n *NBVersionCheck) Name() string {
return NBVersionCheckName
}
func (n *NBVersionCheck) Validate() error {
if n.MinVersion == "" {
return fmt.Errorf("%s minimum version shouldn't be empty", n.Name())
}
if !isVersionValid(n.MinVersion) {
return fmt.Errorf("%s version: %s is not valid", n.Name(), n.MinVersion)
}
return nil
}

View File

@@ -108,3 +108,33 @@ func TestNBVersionCheck_Check(t *testing.T) {
})
}
}
func TestNBVersionCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check NBVersionCheck
expectedError bool
}{
{
name: "Valid NBVersionCheck",
check: NBVersionCheck{MinVersion: "1.0"},
expectedError: false,
},
{
name: "Invalid NBVersionCheck",
check: NBVersionCheck{},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
type PeerNetworkRangeCheck struct {
@@ -52,3 +53,19 @@ func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
func (p *PeerNetworkRangeCheck) Name() string {
return PeerNetworkRangeCheckName
}
func (p *PeerNetworkRangeCheck) Validate() error {
if p.Action == "" {
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
}
allowedActions := []string{CheckActionAllow, CheckActionDeny}
if !slices.Contains(allowedActions, p.Action) {
return fmt.Errorf("%s action is not valid", p.Name())
}
if len(p.Ranges) == 0 {
return fmt.Errorf("%s network ranges shouldn't be empty", p.Name())
}
return nil
}

View File

@@ -147,3 +147,52 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
})
}
}
func TestNetworkCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check PeerNetworkRangeCheck
expectedError bool
}{
{
name: "Valid network range",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
expectedError: false,
},
{
name: "Invalid empty network range",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{},
},
expectedError: true,
},
{
name: "Invalid check action",
check: PeerNetworkRangeCheck{
Action: "unknownAction",
Ranges: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,11 +1,13 @@
package posture
import (
"fmt"
"strings"
"github.com/hashicorp/go-version"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
type MinVersionCheck struct {
@@ -48,6 +50,40 @@ func (c *OSVersionCheck) Name() string {
return OSVersionCheckName
}
func (c *OSVersionCheck) Validate() error {
emptyOS := c.Android == nil && c.Darwin == nil && c.Ios == nil &&
c.Linux == nil && c.Windows == nil
emptyMinVersion := c.Android != nil && c.Android.MinVersion == "" || c.Darwin != nil && c.Darwin.MinVersion == "" ||
c.Ios != nil && c.Ios.MinVersion == "" || c.Linux != nil && c.Linux.MinKernelVersion == "" || c.Windows != nil &&
c.Windows.MinKernelVersion == ""
if emptyOS || emptyMinVersion {
return fmt.Errorf("%s minimum version for at least one OS shouldn't be empty", c.Name())
}
if c.Android != nil && !isVersionValid(c.Android.MinVersion) {
return fmt.Errorf("%s android version: %s is not valid", c.Name(), c.Android.MinVersion)
}
if c.Ios != nil && !isVersionValid(c.Ios.MinVersion) {
return fmt.Errorf("%s ios version: %s is not valid", c.Name(), c.Ios.MinVersion)
}
if c.Darwin != nil && !isVersionValid(c.Darwin.MinVersion) {
return fmt.Errorf("%s darwin version: %s is not valid", c.Name(), c.Darwin.MinVersion)
}
if c.Linux != nil && !isVersionValid(c.Linux.MinKernelVersion) {
return fmt.Errorf("%s linux kernel version: %s is not valid", c.Name(),
c.Linux.MinKernelVersion)
}
if c.Windows != nil && !isVersionValid(c.Windows.MinKernelVersion) {
return fmt.Errorf("%s windows kernel version: %s is not valid", c.Name(),
c.Windows.MinKernelVersion)
}
return nil
}
func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)

View File

@@ -150,3 +150,79 @@ func TestOSVersionCheck_Check(t *testing.T) {
})
}
}
func TestOSVersionCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check OSVersionCheck
expectedError bool
}{
{
name: "Valid linux kernel version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
},
expectedError: false,
},
{
name: "Valid linux and darwin version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &MinVersionCheck{MinVersion: "14.2"},
},
expectedError: false,
},
{
name: "Invalid empty check",
check: OSVersionCheck{},
expectedError: true,
},
{
name: "Invalid empty linux kernel version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{},
},
expectedError: true,
},
{
name: "Invalid empty linux kernel version with correct darwin version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{},
Darwin: &MinVersionCheck{MinVersion: "14.2"},
},
expectedError: true,
},
{
name: "Valid windows kernel version",
check: OSVersionCheck{
Windows: &MinKernelVersionCheck{MinKernelVersion: "10.0"},
},
expectedError: false,
},
{
name: "Valid ios minimum version",
check: OSVersionCheck{
Ios: &MinVersionCheck{MinVersion: "13.0"},
},
expectedError: false,
},
{
name: "Invalid empty window version with valid ios minimum version",
check: OSVersionCheck{
Windows: &MinKernelVersionCheck{},
Ios: &MinVersionCheck{MinVersion: "13.0"},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,64 @@
package posture
import (
"fmt"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
type Process struct {
Path string
WindowsPath string
}
type ProcessCheck struct {
Processes []Process
}
var _ Check = (*ProcessCheck)(nil)
func (p *ProcessCheck) Check(peer nbpeer.Peer) (bool, error) {
peerActiveProcesses := make([]string, 0, len(peer.Meta.Files))
for _, file := range peer.Meta.Files {
if file.ProcessIsRunning {
peerActiveProcesses = append(peerActiveProcesses, file.Path)
}
}
switch peer.Meta.GoOS {
case "darwin", "linux":
for _, process := range p.Processes {
if process.Path == "" || !slices.Contains(peerActiveProcesses, process.Path) {
return false, nil
}
}
return true, nil
case "windows":
for _, process := range p.Processes {
if process.WindowsPath == "" || !slices.Contains(peerActiveProcesses, process.WindowsPath) {
return false, nil
}
}
return true, nil
default:
return false, fmt.Errorf("unsupported peer's operating system: %s", peer.Meta.GoOS)
}
}
func (p *ProcessCheck) Name() string {
return ProcessCheckName
}
func (p *ProcessCheck) Validate() error {
if len(p.Processes) == 0 {
return fmt.Errorf("%s processes shouldn't be empty", p.Name())
}
for _, process := range p.Processes {
if process.Path == "" && process.WindowsPath == "" {
return fmt.Errorf("%s path shouldn't be empty", p.Name())
}
}
return nil
}

View File

@@ -0,0 +1,305 @@
package posture
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/peer"
)
func TestProcessCheck_Check(t *testing.T) {
tests := []struct {
name string
input peer.Peer
check ProcessCheck
wantErr bool
isValid bool
}{
{
name: "darwin with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
Files: []peer.File{
{Path: "/Applications/process1.app", ProcessIsRunning: true},
{Path: "/Applications/process2.app", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/Applications/process1.app"},
{Path: "/Applications/process2.app"},
},
},
wantErr: false,
isValid: true,
},
{
name: "darwin with windows process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
Files: []peer.File{
{Path: "/Applications/process1.app", ProcessIsRunning: true},
{Path: "/Applications/process2.app", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: true,
},
{
name: "linux with matching no running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2", ProcessIsRunning: false},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with windows process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with non-matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process3"},
{Path: "/usr/bin/process4"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: false,
},
{
name: "windows with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process1.exe"},
},
},
wantErr: false,
isValid: true,
},
{
name: "windows with darwin process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process1.exe"},
{Path: "C:\\Program Files\\process1.exe"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/Applications/process1.app"},
{Path: "/Applications/process2.app"},
},
},
wantErr: false,
isValid: false,
},
{
name: "windows with non-matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process3.exe"},
{Path: "C:\\Program Files\\process4.exe"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "unsupported ios operating system",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "ios",
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "C:\\Program Files\\process1.exe"},
{Path: "C:\\Program Files\\process2.exe"},
},
},
wantErr: true,
isValid: false,
},
{
name: "unsupported android operating system with matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "android",
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: true,
isValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid, err := tt.check.Check(tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.isValid, isValid)
})
}
}
func TestProcessCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check ProcessCheck
expectedError bool
}{
{
name: "Valid unix and windows processes",
check: ProcessCheck{
Processes: []Process{
{
Path: "/usr/local/bin/netbird",
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
expectedError: false,
},
{
name: "Valid unix process",
check: ProcessCheck{
Processes: []Process{
{
Path: "/usr/local/bin/netbird",
},
},
},
expectedError: false,
},
{
name: "Valid windows process",
check: ProcessCheck{
Processes: []Process{
{
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
expectedError: false,
},
{
name: "Invalid empty processes",
check: ProcessCheck{
Processes: []Process{},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -52,7 +52,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
}
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.BadRequest, err.Error())
return status.Errorf(status.InvalidArgument, err.Error())
}
exists, uniqName := am.savePostureChecks(account, postureChecks)

View File

@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
case <-ticker.C:
select {
case <-cancel:
log.Debugf("scheduled job %s was canceled, stop timer", ID)
log.Tracef("scheduled job %s was canceled, stop timer", ID)
ticker.Stop()
return
default:
log.Debugf("time to do a scheduled job %s", ID)
log.Tracef("time to do a scheduled job %s", ID)
}
runIn, reschedule := job()
if !reschedule {
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
log.Debugf("job %s is not scheduled to run again", ID)
log.Tracef("job %s is not scheduled to run again", ID)
ticker.Stop()
return
}
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
ticker.Reset(runIn)
}
case <-cancel:
log.Debugf("job %s was canceled, stopping timer", ID)
log.Tracef("job %s was canceled, stopping timer", ID)
ticker.Stop()
return
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock
}
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
// Get the reflect.Value of the records slice
v := reflect.ValueOf(records)
if v.Kind() != reflect.Slice {
return fmt.Errorf("provided input is not a slice")
}
// Insert records in batches
for i := 0; i < v.Len(); i += batchSize {
end := i + batchSize
if end > v.Len() {
end = v.Len()
}
// Use reflect.Slice to get a slice of the records for the current batch
batch := v.Slice(i, end).Interface()
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
return err
}
}
return nil
}
func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now()
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
// operate over a fresh copy as we will modify its fields
accCopy := account.Copy()
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
for _, key := range accCopy.SetupKeys {
//we need an explicit reference to the account for gorm
key.AccountID = accCopy.Id
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
}
for id, peer := range account.Peers {
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
for id, peer := range accCopy.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
//we need an explicit reference to the account for gorm
peer.AccountID = accCopy.Id
accCopy.PeersG = append(accCopy.PeersG, *peer)
}
for id, user := range account.Users {
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
for id, user := range accCopy.Users {
user.Id = id
//we need an explicit reference to the account for gorm
user.AccountID = accCopy.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
accCopy.UsersG = append(accCopy.UsersG, *user)
}
for id, group := range account.Groups {
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
for id, group := range accCopy.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
//we need an explicit reference to the account for gorm
group.AccountID = accCopy.Id
accCopy.GroupsG = append(accCopy.GroupsG, *group)
}
for id, route := range account.Routes {
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
for id, route := range accCopy.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
//we need an explicit reference to the account for gorm
route.AccountID = accCopy.Id
accCopy.RoutesG = append(accCopy.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
for id, ns := range accCopy.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
//we need an explicit reference to the account for gorm
ns.AccountID = accCopy.Id
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
result = tx.Select(clause.Associations).Delete(accCopy)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
Clauses(clause.OnConflict{UpdateAll: true}).
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
Create(accCopy)
if result.Error != nil {
return result.Error
}
return nil
const batchSize = 500
err := batchInsert(accCopy.PeersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.UsersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.GroupsG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.RoutesG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
if err != nil {
return err
}
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
return err
}
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
func (s *SqliteStore) DeleteAccount(account *Account) error {
start := time.Now()
account.UsersG = make([]User, 0, len(account.Users))
for id, user := range account.Users {
user.Id = id
//we need an explicit reference to an account as it is missing for some reason
user.AccountID = account.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {

View File

@@ -2,7 +2,12 @@ package server
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
route2 "github.com/netbirdio/netbird/route"
"math/rand"
"net"
"net/netip"
"path/filepath"
"runtime"
"testing"
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
admin := account.Users["testuser"]
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
HashedToken: "hashed token",
}}
err := store.SaveAccount(account)
require.NoError(t, err)
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
store := newSqliteStore(t)
testUserID := "testuser"
user := NewAdminUser(testUserID)
user := NewAdminUser(testUserID, "account_id")
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account)
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}

View File

@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
accountID string) *User {
return &User{
Id: id,
Id: ID,
AccountID: accountID,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
func NewRegularUser(ID, accountID string) *User {
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
func NewAdminUser(ID, accountID string) *User {
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
func NewOwnerUser(ID, accountID string) *User {
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser

View File

@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
err := store.SaveAccount(account)
if err != nil {
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
func TestUser_IsAdmin(t *testing.T) {
user := NewAdminUser(mockUserID)
user := NewAdminUser(mockUserID, mockAccountID)
assert.True(t, user.HasAdminPower())
user = NewRegularUser(mockUserID)
user = NewRegularUser(mockUserID, mockAccountID)
assert.False(t, user.HasAdminPower())
}
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
}
// create other users
account.Users[regularUserID] = NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID)
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(account)
if err != nil {

View File

@@ -3,6 +3,8 @@ package grpc
import (
"context"
"net"
"os/user"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -12,6 +14,20 @@ import (
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
log.Fatalf("failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)

View File

@@ -49,6 +49,10 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
@@ -56,7 +60,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
connID := GenerateConnID()
if dialerDialHooks != nil {
if err := calliDialerHooks(ctx, connID, address, resolver); err != nil {
if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
log.Errorf("Failed to call dialer hooks: %v", err)
}
}
@@ -97,7 +101,7 @@ func (c *Conn) Close() error {
return err
}
func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("split host and port: %w", err)
@@ -123,6 +127,10 @@ func calliDialerHooks(ctx context.Context, connID ConnectionID, address string,
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
@@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr

View File

@@ -8,6 +8,7 @@ import (
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
@@ -52,6 +53,10 @@ func RemoveListenerHooks() {
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
@@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)

View File

@@ -1,10 +1,16 @@
package net
import "github.com/google/uuid"
import (
"os"
"github.com/google/uuid"
)
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)
// ConnectionID provides a globally unique identifier for network connections.
@@ -15,3 +21,7 @@ type ConnectionID string
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func CustomRoutingDisabled() bool {
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
setErr = SetSocketOpt(int(fd))
})
if err != nil {
return fmt.Errorf("control: %w", err)
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil
}
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
return nil
}
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
}