mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
16 Commits
use-conntr
...
refactor/p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e4fea16c9f | ||
|
|
b3ab142d2e | ||
|
|
de457788ba | ||
|
|
09243a0fe0 | ||
|
|
3658215747 | ||
|
|
48ffec95dd | ||
|
|
cbec7bda80 | ||
|
|
21464ac770 | ||
|
|
ed5647028a | ||
|
|
29a6e5be71 | ||
|
|
6124e3b937 | ||
|
|
50f5cc48cd | ||
|
|
101cce27f2 | ||
|
|
a4f04f5570 | ||
|
|
fceb3ca392 | ||
|
|
34d86c5ab8 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -32,7 +34,7 @@ import (
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
t.Helper()
|
||||
config := &mgmt.Config{}
|
||||
config := &types.Config{}
|
||||
_, err := util.ReadJson("../testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -67,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
@@ -90,13 +92,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -12,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
|
||||
@@ -14,8 +14,13 @@ import (
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
}
|
||||
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
@@ -27,35 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
ipStr := ip.String()
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -73,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -86,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
@@ -120,12 +149,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ip.Is4() {
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
return false
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
@@ -77,6 +77,18 @@ func TestLocalIPManager(t *testing.T) {
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
@@ -192,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
@@ -248,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -257,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -24,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@@ -15,7 +14,7 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
|
||||
@@ -99,6 +99,7 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
oauth2.SetAuthURLParam("prompt", "login"),
|
||||
)
|
||||
|
||||
return AuthFlowInfo{
|
||||
|
||||
@@ -31,7 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
|
||||
@@ -38,7 +38,6 @@ const (
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
}
|
||||
|
||||
@@ -124,18 +123,19 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
|
||||
return fmt.Errorf("set link as default dns router: %w", err)
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: nbdns.RootZone,
|
||||
MatchOnly: true,
|
||||
})
|
||||
s.routingAll = true
|
||||
} else if s.routingAll {
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
|
||||
return fmt.Errorf("remove link as default dns router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
state := &ShutdownState{
|
||||
|
||||
@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
|
||||
@@ -49,6 +49,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -1400,15 +1401,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Relay: &server.Relay{
|
||||
config := &types.Config{
|
||||
Stuns: []*types.Host{},
|
||||
TURNConfig: &types.TURNConfig{},
|
||||
Relay: &types.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &server.Host{
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
@@ -1438,6 +1439,8 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
@@ -1446,7 +1449,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -19,11 +19,9 @@ import (
|
||||
type rcvChan chan *types.EventFields
|
||||
type Logger struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
enabled atomic.Bool
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancelReceiver context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceIPNet net.IPNet
|
||||
dnsCollection atomic.Bool
|
||||
@@ -31,12 +29,9 @@ type Logger struct {
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Logger{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceIPNet: wgIfaceIPNet,
|
||||
Store: store.NewMemoryStore(),
|
||||
@@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
|
||||
}
|
||||
|
||||
l.mux.Lock()
|
||||
ctx, cancel := context.WithCancel(l.ctx)
|
||||
l.cancelReceiver = cancel
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
l.mux.Unlock()
|
||||
|
||||
c := make(rcvChan, 100)
|
||||
@@ -109,7 +104,7 @@ func (l *Logger) startReceiver() {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Disable() {
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.Store.Close()
|
||||
}
|
||||
@@ -121,9 +116,9 @@ func (l *Logger) stop() {
|
||||
|
||||
l.enabled.Store(false)
|
||||
l.mux.Lock()
|
||||
if l.cancelReceiver != nil {
|
||||
l.cancelReceiver()
|
||||
l.cancelReceiver = nil
|
||||
if l.cancel != nil {
|
||||
l.cancel()
|
||||
l.cancel = nil
|
||||
}
|
||||
l.rcvChan.Store(nil)
|
||||
l.mux.Unlock()
|
||||
@@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
l.exitNodeCollection.Store(exitNodeCollection)
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.cancel()
|
||||
}
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(context.Background(), nil, net.IPNet{})
|
||||
logger := logger.New(nil, net.IPNet{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
@@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
|
||||
}
|
||||
|
||||
// test disable
|
||||
logger.Disable()
|
||||
logger.Close()
|
||||
wait()
|
||||
logger.StoreEvent(event)
|
||||
wait()
|
||||
|
||||
@@ -27,18 +27,18 @@ type Manager struct {
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
ctx context.Context
|
||||
receiverClient *client.GRPCClient
|
||||
publicKey []byte
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var ipNet net.IPNet
|
||||
if iface != nil {
|
||||
ipNet = *iface.Address().Network
|
||||
}
|
||||
flowLogger := logger.New(ctx, statusRecorder, ipNet)
|
||||
flowLogger := logger.New(statusRecorder, ipNet)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
@@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
|
||||
return &Manager{
|
||||
logger: flowLogger,
|
||||
conntrack: ct,
|
||||
ctx: ctx,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
}
|
||||
@@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
// first make sender ready so events don't pile up
|
||||
if m.needsNewClient(previous) {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
if err := m.resetClient(); err != nil {
|
||||
return fmt.Errorf("reset client: %w", err)
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
go m.receiveACKs(flowClient)
|
||||
go m.startSender()
|
||||
}
|
||||
|
||||
m.logger.Enable()
|
||||
@@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetClient() error {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
go m.receiveACKs(ctx, flowClient)
|
||||
go m.startSender(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableFlow stops components for flow tracking
|
||||
func (m *Manager) disableFlow() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Stop()
|
||||
}
|
||||
|
||||
m.logger.Disable()
|
||||
m.logger.Close()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
|
||||
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
|
||||
|
||||
changed := previous != nil && update.Enabled != previous.Enabled
|
||||
if update.Enabled {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
if changed {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
}
|
||||
return m.enableFlow(previous)
|
||||
}
|
||||
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
err := m.disableFlow()
|
||||
if err != nil {
|
||||
log.Errorf("failed to disable netflow manager: %v", err)
|
||||
if changed {
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
}
|
||||
return err
|
||||
return m.disableFlow()
|
||||
}
|
||||
|
||||
// Close cleans up all resources
|
||||
@@ -151,17 +172,9 @@ func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Close()
|
||||
if err := m.disableFlow(); err != nil {
|
||||
log.Warnf("failed to disable flow manager: %v", err)
|
||||
}
|
||||
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("failed to close receiver client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Close()
|
||||
}
|
||||
|
||||
// GetLogger returns the flow logger
|
||||
@@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||
return m.logger
|
||||
}
|
||||
|
||||
func (m *Manager) startSender() {
|
||||
func (m *Manager) startSender(ctx context.Context) {
|
||||
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
events := m.logger.GetEvents()
|
||||
@@ -190,8 +203,8 @@ func (m *Manager) startSender() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
|
||||
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
id, err := uuid.FromBytes(ack.EventId)
|
||||
if err != nil {
|
||||
log.Warnf("failed to convert ack event id to uuid: %v", err)
|
||||
|
||||
200
client/internal/netflow/manager_test.go
Normal file
200
client/internal/netflow/manager_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package netflow
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type mockIFaceMapper struct {
|
||||
address wgaddr.Address
|
||||
isUserspaceBind bool
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Name() string {
|
||||
return "wt0"
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||
return m.address
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||
return m.isUserspaceBind
|
||||
}
|
||||
|
||||
func TestManager_Update(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
statusRecorder := peer.NewRecorder("")
|
||||
|
||||
manager := NewManager(mockIFace, publicKey, statusRecorder)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.FlowConfig
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "disabled config",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled config with minimal valid settings",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: true,
|
||||
URL: "https://example.com",
|
||||
TokenPayload: "test-payload",
|
||||
TokenSignature: "test-signature",
|
||||
Interval: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.Update(tc.config)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tc.config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, manager.flowConfig)
|
||||
|
||||
if tc.config.Enabled {
|
||||
assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
|
||||
}
|
||||
|
||||
if tc.config.URL != "" {
|
||||
assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
|
||||
}
|
||||
|
||||
if tc.config.TokenPayload != "" {
|
||||
assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
manager := NewManager(mockIFace, publicKey, nil)
|
||||
|
||||
// First update with tokens
|
||||
initialConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
TokenPayload: "initial-payload",
|
||||
TokenSignature: "initial-signature",
|
||||
}
|
||||
|
||||
err := manager.Update(initialConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second update without tokens should preserve them
|
||||
updatedConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
URL: "https://example.com",
|
||||
}
|
||||
|
||||
err = manager.Update(updatedConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tokens were preserved
|
||||
assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload)
|
||||
assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature)
|
||||
}
|
||||
|
||||
func TestManager_NeedsNewClient(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
previous *types.FlowConfig
|
||||
current *types.FlowConfig
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil previous config",
|
||||
previous: nil,
|
||||
current: &types.FlowConfig{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "previous disabled",
|
||||
previous: &types.FlowConfig{Enabled: false},
|
||||
current: &types.FlowConfig{Enabled: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different URL",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenPayload",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenSignature",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same config",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "only interval changed",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
manager.flowConfig = tc.current
|
||||
result := manager.needsNewClient(tc.previous)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -120,9 +120,6 @@ type FlowLogger interface {
|
||||
Close()
|
||||
// Enable enables the flow logger receiver
|
||||
Enable()
|
||||
// Disable disables the flow logger receiver
|
||||
Disable()
|
||||
|
||||
// UpdateConfig updates the flow manager configuration
|
||||
UpdateConfig(dnsCollection, exitNodeCollection bool)
|
||||
}
|
||||
|
||||
@@ -103,6 +103,7 @@ type Conn struct {
|
||||
|
||||
workerICE *WorkerICE
|
||||
workerRelay *WorkerRelay
|
||||
wgWatcherWg sync.WaitGroup
|
||||
|
||||
connIDRelay nbnet.ConnectionID
|
||||
connIDICE nbnet.ConnectionID
|
||||
@@ -211,6 +212,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
func (conn *Conn) Close() {
|
||||
conn.mu.Lock()
|
||||
defer conn.wgWatcherWg.Wait()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
conn.log.Infof("close peer connection")
|
||||
@@ -252,6 +254,7 @@ func (conn *Conn) Close() {
|
||||
}
|
||||
|
||||
conn.setStatusToDisconnected()
|
||||
conn.log.Infof("peer connection has been closed")
|
||||
}
|
||||
|
||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
@@ -362,6 +365,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
@@ -407,7 +411,12 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
} else {
|
||||
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
||||
@@ -476,7 +485,12 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
@@ -502,6 +516,7 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
conn.currentConnPriority = connPriorityNone
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
|
||||
@@ -586,9 +586,8 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
if err == nil {
|
||||
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
if d.localPeer.Routes == nil {
|
||||
@@ -596,8 +595,6 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||
}
|
||||
|
||||
d.localPeer.Routes[route] = struct{}{}
|
||||
|
||||
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
// RemoveLocalPeerStateRoute removes a route from the local peer state
|
||||
@@ -606,14 +603,11 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
if err == nil {
|
||||
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||
}
|
||||
|
||||
delete(d.localPeer.Routes, route)
|
||||
|
||||
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||
}
|
||||
|
||||
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
|
||||
|
||||
@@ -32,7 +32,6 @@ type WGWatcher struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
waitGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -48,24 +47,24 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
w.log.Errorf("WireGuard watcher already enabled")
|
||||
w.ctxLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||
w.ctx = ctx
|
||||
w.ctxCancel = ctxCancel
|
||||
w.ctxLock.Unlock()
|
||||
|
||||
initialHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
|
||||
w.waitGroup.Add(1)
|
||||
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
}
|
||||
|
||||
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
||||
@@ -81,13 +80,11 @@ func (w *WGWatcher) DisableWgWatcher() {
|
||||
|
||||
w.ctxCancel()
|
||||
w.ctxCancel = nil
|
||||
w.waitGroup.Wait()
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
defer w.waitGroup.Done()
|
||||
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
watcher.EnableWgWatcher(ctx, func() {
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
mlog.Infof("onDisconnectedFn")
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
@@ -79,10 +79,11 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
|
||||
watcher.EnableWgWatcher(ctx, func() {})
|
||||
go watcher.EnableWgWatcher(ctx, func() {})
|
||||
time.Sleep(1 * time.Second)
|
||||
watcher.DisableWgWatcher()
|
||||
|
||||
watcher.EnableWgWatcher(ctx, func() {
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
|
||||
iFaceDiscover: ifaceDiscover,
|
||||
statusRecorder: statusRecorder,
|
||||
hasRelayOnLocally: hasRelayOnLocally,
|
||||
lastKnownState: ice.ConnectionStateDisconnected,
|
||||
}
|
||||
|
||||
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
|
||||
@@ -213,7 +214,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
|
||||
@@ -84,6 +84,7 @@ func New(ctx context.Context, configPath, logFile string) *Server {
|
||||
},
|
||||
logFile: logFile,
|
||||
persistNetworkMap: true,
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,9 +137,6 @@ func (s *Server) Start() error {
|
||||
|
||||
s.config = config
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
}
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
|
||||
@@ -622,9 +620,6 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
return nil, fmt.Errorf("config is not defined, please call login command first")
|
||||
}
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
|
||||
}
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
@@ -692,9 +687,6 @@ func (s *Server) Status(
|
||||
|
||||
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
|
||||
}
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
|
||||
@@ -3,27 +3,32 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
@@ -83,6 +88,72 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
s := New(ctx, t.TempDir()+"/config.json", "console")
|
||||
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
|
||||
require.NoError(t, err)
|
||||
s.config = &internal.Config{
|
||||
ManagementURL: u,
|
||||
}
|
||||
|
||||
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
upReq := &daemonProto.UpRequest{}
|
||||
_, err = s.Up(upCtx, upReq)
|
||||
|
||||
assert.Contains(t, err.Error(), "NeedsLogin")
|
||||
}
|
||||
|
||||
type mockSubscribeEventsServer struct {
|
||||
ctx context.Context
|
||||
sentEvents []*daemonProto.SystemEvent
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (m *mockSubscribeEventsServer) Send(event *daemonProto.SystemEvent) error {
|
||||
m.sentEvents = append(m.sentEvents, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSubscribeEventsServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func TestServer_SubcribeEvents(t *testing.T) {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
s := New(ctx, t.TempDir()+"/config.json", "console")
|
||||
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
|
||||
require.NoError(t, err)
|
||||
s.config = &internal.Config{
|
||||
ManagementURL: u,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
upReq := &daemonProto.SubscribeRequest{}
|
||||
mockServer := &mockSubscribeEventsServer{
|
||||
ctx: ctx,
|
||||
sentEvents: make([]*daemonProto.SystemEvent, 0),
|
||||
ServerStream: nil,
|
||||
}
|
||||
err = s.SubscribeEvents(upReq, mockServer)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
@@ -97,10 +168,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Signal: &server.Host{
|
||||
config := &types.Config{
|
||||
Stuns: []*types.Host{},
|
||||
TURNConfig: &types.TURNConfig{},
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
@@ -129,11 +200,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
@@ -77,17 +79,24 @@ func (c *GRPCClient) Close() error {
|
||||
defer c.streamMu.Unlock()
|
||||
|
||||
c.stream = nil
|
||||
return c.clientConn.Close()
|
||||
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
return fmt.Errorf("close client connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
backOff := defaultBackoff(ctx, interval)
|
||||
operation := func() error {
|
||||
err := c.establishStreamAndReceive(ctx, msgHandler)
|
||||
if err != nil {
|
||||
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
|
||||
}
|
||||
log.Errorf("receive failed: %v", err)
|
||||
return fmt.Errorf("receive: %w", err)
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := backoff.Retry(operation, backOff); err != nil {
|
||||
|
||||
256
flow/client/client_test.go
Normal file
256
flow/client/client_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
flow "github.com/netbirdio/netbird/flow/client"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
)
|
||||
|
||||
type testServer struct {
|
||||
proto.UnimplementedFlowServiceServer
|
||||
events chan *proto.FlowEvent
|
||||
acks chan *proto.FlowEventAck
|
||||
grpcSrv *grpc.Server
|
||||
addr string
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *testServer {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &testServer{
|
||||
events: make(chan *proto.FlowEvent, 100),
|
||||
acks: make(chan *proto.FlowEventAck, 100),
|
||||
grpcSrv: grpc.NewServer(),
|
||||
addr: listener.Addr().String(),
|
||||
}
|
||||
|
||||
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
||||
|
||||
go func() {
|
||||
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
||||
t.Logf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.grpcSrv.Stop()
|
||||
})
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(stream.Context())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
for {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !event.IsInitiator {
|
||||
select {
|
||||
case s.events <- event:
|
||||
ack := &proto.FlowEventAck{
|
||||
EventId: event.EventId,
|
||||
}
|
||||
select {
|
||||
case s.acks <- ack:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ack := <-s.acks:
|
||||
if err := stream.Send(ack); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceive(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
receivedAcks := make(map[string]bool)
|
||||
receiveDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
||||
id := string(msg.EventId)
|
||||
receivedAcks[id] = true
|
||||
|
||||
if len(receivedAcks) >= 3 {
|
||||
close(receiveDone)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
eventID := uuid.New().String()
|
||||
|
||||
// Create acknowledgment and send it to the flow through our test server
|
||||
ack := &proto.FlowEventAck{
|
||||
EventId: []byte(eventID),
|
||||
}
|
||||
|
||||
select {
|
||||
case server.acks <- ack:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout sending ack")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for acks to be processed")
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, len(receivedAcks))
|
||||
}
|
||||
|
||||
func TestReceive_ContextCancellation(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
handlerCalled := false
|
||||
msgHandler := func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator {
|
||||
handlerCalled = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Receive(ctx, 1*time.Second, msgHandler)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
assert.False(t, handlerCalled)
|
||||
}
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
ackReceived := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
|
||||
if len(ack.EventId) > 0 && !ack.IsInitiator {
|
||||
close(ackReceived)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
testEvent := &proto.FlowEvent{
|
||||
EventId: []byte("test-event-id"),
|
||||
PublicKey: []byte("test-public-key"),
|
||||
FlowFields: &proto.FlowFields{
|
||||
FlowId: []byte("test-flow-id"),
|
||||
Protocol: 6,
|
||||
SourceIp: []byte{192, 168, 1, 1},
|
||||
DestIp: []byte{192, 168, 1, 2},
|
||||
ConnectionInfo: &proto.FlowFields_PortInfo{
|
||||
PortInfo: &proto.PortInfo{
|
||||
SourcePort: 12345,
|
||||
DestPort: 443,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = client.Send(testEvent)
|
||||
require.NoError(t, err)
|
||||
|
||||
var receivedEvent *proto.FlowEvent
|
||||
select {
|
||||
case receivedEvent = <-server.events:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for event to be received by server")
|
||||
}
|
||||
|
||||
assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
|
||||
assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)
|
||||
|
||||
select {
|
||||
case <-ackReceived:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ack to be received by flow")
|
||||
}
|
||||
}
|
||||
2
go.mod
2
go.mod
@@ -62,7 +62,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -490,8 +490,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 h1:OsLDdb6ekNaCVSyD+omhio2DECfEqLjCA1zo4HrgGqU=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203 h1:uxxbLPXQgC9VO15epNPtrD6zazyd5rZeqC5hQSmCdZU=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -50,7 +51,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
level, _ := log.ParseLevel("debug")
|
||||
log.SetLevel(level)
|
||||
|
||||
config := &mgmt.Config{}
|
||||
config := &types.Config{}
|
||||
_, err := util.ReadJson("../server/testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -74,6 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
@@ -87,7 +89,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -34,7 +34,9 @@ import (
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
@@ -50,7 +52,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -70,7 +71,7 @@ var (
|
||||
mgmtSingleAccModeDomain string
|
||||
certFile string
|
||||
certKey string
|
||||
config *server.Config
|
||||
config *types.Config
|
||||
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
@@ -101,9 +102,9 @@ var (
|
||||
// detect whether user specified a port
|
||||
userPort := cmd.Flag("port").Changed
|
||||
|
||||
config, err = loadMgmtConfig(ctx, mgmtConfig)
|
||||
config, err = loadMgmtConfig(ctx, types.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
|
||||
return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err)
|
||||
}
|
||||
|
||||
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
|
||||
@@ -183,7 +184,7 @@ var (
|
||||
if config.DataStoreEncryptionKey != key {
|
||||
log.WithContext(ctx).Infof("update config with activity store key")
|
||||
config.DataStoreEncryptionKey = key
|
||||
err := updateMgmtConfig(ctx, mgmtConfig, config)
|
||||
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write out store encryption key: %s", err)
|
||||
}
|
||||
@@ -201,15 +202,14 @@ var (
|
||||
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
||||
}
|
||||
|
||||
permissionsManager := integrations.InitPermissionsManager(store)
|
||||
userManager := users.NewManager(store)
|
||||
extraSettingsManager := integrations.NewManager(eventStore)
|
||||
settingsManager := settings.NewManager(store, userManager, extraSettingsManager)
|
||||
permissionsManager := permissions.NewManager(userManager, settingsManager)
|
||||
settingsManager := settings.NewManager(store, userManager, extraSettingsManager, permissionsManager)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
proxyController := integrations.NewController(store)
|
||||
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager)
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
@@ -486,8 +486,8 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
|
||||
})
|
||||
}
|
||||
|
||||
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
|
||||
loadedConfig := &server.Config{}
|
||||
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) {
|
||||
loadedConfig := &types.Config{}
|
||||
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -522,7 +522,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
|
||||
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
|
||||
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
||||
|
||||
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
|
||||
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) {
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
@@ -539,7 +539,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
|
||||
|
||||
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope
|
||||
}
|
||||
}
|
||||
|
||||
@@ -560,7 +560,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
|
||||
return loadedConfig, err
|
||||
}
|
||||
|
||||
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
|
||||
func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error {
|
||||
return util.DirectWriteJson(ctx, path, config)
|
||||
}
|
||||
|
||||
@@ -636,7 +636,7 @@ func handleRebrand(cmd *cobra.Command) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
if mgmtConfig == defaultMgmtConfig {
|
||||
if types.MgmtConfigPath == defaultMgmtConfig {
|
||||
if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) {
|
||||
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
|
||||
err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -19,7 +20,6 @@ const (
|
||||
var (
|
||||
dnsDomain string
|
||||
mgmtDataDir string
|
||||
mgmtConfig string
|
||||
logLevel string
|
||||
logFile string
|
||||
disableMetrics bool
|
||||
@@ -56,7 +56,7 @@ func init() {
|
||||
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
||||
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||
mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
|
||||
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -89,6 +90,8 @@ type DefaultAccountManager struct {
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
metrics telemetry.AppMetrics
|
||||
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||
@@ -156,6 +159,7 @@ func BuildManager(
|
||||
metrics telemetry.AppMetrics,
|
||||
proxyController port_forwarding.Controller,
|
||||
settingsManager settings.Manager,
|
||||
permissionsManager permissions.Manager,
|
||||
) (*DefaultAccountManager, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -180,6 +184,7 @@ func BuildManager(
|
||||
requestBuffer: NewAccountRequestBuffer(ctx, store),
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
accountsCounter, err := store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
@@ -253,13 +258,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Write)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
|
||||
@@ -503,16 +508,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Accounts, permissions.Write)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
|
||||
}
|
||||
|
||||
if user.Role != types.UserRoleOwner {
|
||||
if !allowed {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||
}
|
||||
|
||||
@@ -542,14 +543,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
}
|
||||
|
||||
userInfo, ok := userInfosMap[userID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "user info not found for user %s", userID)
|
||||
}
|
||||
|
||||
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
|
||||
return err
|
||||
if ok {
|
||||
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.Store.DeleteAccount(ctx, account)
|
||||
@@ -1027,8 +1026,8 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
@@ -1061,8 +1060,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID)
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && userAuth.Invited {
|
||||
@@ -1521,7 +1520,11 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
}
|
||||
|
||||
@@ -1606,3 +1609,113 @@ func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, ma
|
||||
func (am *DefaultAccountManager) GetStore() store.Store {
|
||||
return am.Store
|
||||
}
|
||||
|
||||
// Creates account by private domain.
|
||||
// Expects domain value to be a valid and a private dns domain.
|
||||
func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
|
||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||
defer cancel()
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists")
|
||||
}
|
||||
|
||||
// retry twice for new ID clashes
|
||||
for range 2 {
|
||||
accountId := xid.New().String()
|
||||
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
|
||||
if err != nil || exists {
|
||||
continue
|
||||
}
|
||||
|
||||
network := types.NewNetwork()
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
setupKeys := map[string]*types.SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
|
||||
dnsSettings := types.DNSSettings{
|
||||
DisabledManagementGroups: make([]string, 0),
|
||||
}
|
||||
|
||||
newAccount := &types.Account{
|
||||
Id: accountId,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
// @todo check if using the MSP owner id here is ok
|
||||
CreatedBy: initiatorId,
|
||||
Domain: domain,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
IsDomainPrimaryAccount: false,
|
||||
Routes: routes,
|
||||
NameServerGroups: nameServersGroups,
|
||||
DNSSettings: dnsSettings,
|
||||
Settings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
GroupsPropagationEnabled: true,
|
||||
RegularUsersViewBlocked: true,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
RoutingPeerDNSResolutionEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
if err := newAccount.AddAllGroup(); err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||
}
|
||||
|
||||
if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.Internal, "failed to create new account by private domain")
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.IsDomainPrimaryAccount {
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// additional check to ensure there is only one account for this domain at the time of update
|
||||
count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain")
|
||||
}
|
||||
|
||||
account.IsDomainPrimaryAccount = true
|
||||
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err)
|
||||
return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
@@ -111,4 +111,7 @@ type Manager interface {
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
GetStore() store.Store
|
||||
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
@@ -2793,15 +2794,15 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type TB interface {
|
||||
Cleanup(func())
|
||||
Helper()
|
||||
TempDir() string
|
||||
Errorf(format string, args ...interface{})
|
||||
Fatalf(format string, args ...interface{})
|
||||
}
|
||||
//type TB interface {
|
||||
// Cleanup(func())
|
||||
// Helper()
|
||||
// TempDir() string
|
||||
// Errorf(format string, args ...interface{})
|
||||
// Fatalf(format string, args ...interface{})
|
||||
//}
|
||||
|
||||
func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
@@ -2815,6 +2816,8 @@ func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
@@ -2828,7 +2831,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
Return(false, nil).
|
||||
AnyTimes()
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2836,7 +2839,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func createStore(t TB) (store.Store, error) {
|
||||
func createStore(t testing.TB) (store.Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
@@ -3150,3 +3153,51 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
|
||||
assert.Equal(t, initiatorId, account.CreatedBy)
|
||||
assert.Equal(t, 1, len(account.Groups))
|
||||
assert.Equal(t, 0, len(account.Users))
|
||||
assert.Equal(t, 0, len(account.SetupKeys))
|
||||
|
||||
// retry should fail
|
||||
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
// retry should fail
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
}
|
||||
|
||||
@@ -67,8 +67,8 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -89,8 +89,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -210,13 +211,14 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func isEnabled() bool {
|
||||
@@ -19,16 +21,12 @@ func isEnabled() bool {
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -58,6 +56,11 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
|
||||
filtered = append(filtered, event)
|
||||
}
|
||||
|
||||
err = am.fillEventsWithUserInfo(ctx, events, accountID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
@@ -79,3 +82,156 @@ func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, ta
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
type eventUserInfo struct {
|
||||
email string
|
||||
name string
|
||||
accountId string
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) error {
|
||||
eventUserInfo, err := am.getEventsUserInfo(ctx, events, accountId, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
if !fillEventInitiatorInfo(eventUserInfo, event) {
|
||||
log.WithContext(ctx).Warnf("failed to resolve user info for initiator: %s", event.InitiatorID)
|
||||
}
|
||||
|
||||
fillEventTargetInfo(eventUserInfo, event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) (map[string]eventUserInfo, error) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// @note check whether using a external initiator user here is an issue
|
||||
userInfos, err := am.BuildUserInfosForAccount(ctx, accountId, user.Id, accountUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eventUserInfos := make(map[string]eventUserInfo)
|
||||
for i, k := range userInfos {
|
||||
eventUserInfos[i] = eventUserInfo{
|
||||
email: k.Email,
|
||||
name: k.Name,
|
||||
accountId: accountId,
|
||||
}
|
||||
}
|
||||
|
||||
externalUserIds := []string{}
|
||||
for _, event := range events {
|
||||
if _, ok := eventUserInfos[event.InitiatorID]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if event.InitiatorID == activity.SystemInitiator ||
|
||||
event.InitiatorID == accountId ||
|
||||
event.Activity == activity.PeerAddedWithSetupKey {
|
||||
// @todo other events to be excluded if never initiated by a user
|
||||
continue
|
||||
}
|
||||
|
||||
externalUserIds = append(externalUserIds, event.InitiatorID)
|
||||
}
|
||||
|
||||
if len(externalUserIds) == 0 {
|
||||
return eventUserInfos, nil
|
||||
}
|
||||
|
||||
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, user)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, user *types.User) (map[string]eventUserInfo, error) {
|
||||
externalAccountId := ""
|
||||
fetched := make(map[string]struct{})
|
||||
externalUsers := []*types.User{}
|
||||
for _, id := range externalUserIds {
|
||||
if _, ok := fetched[id]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
||||
if err != nil {
|
||||
// @todo consider logging
|
||||
continue
|
||||
}
|
||||
|
||||
if externalAccountId != "" && externalAccountId != externalUser.AccountID {
|
||||
return nil, fmt.Errorf("multiple external user accounts in events")
|
||||
}
|
||||
|
||||
if externalAccountId == "" {
|
||||
externalAccountId = externalUser.AccountID
|
||||
}
|
||||
|
||||
fetched[id] = struct{}{}
|
||||
externalUsers = append(externalUsers, externalUser)
|
||||
}
|
||||
|
||||
// if we couldn't determine an account, return what we have
|
||||
if externalAccountId == "" {
|
||||
log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds)
|
||||
return eventUserInfos, nil
|
||||
}
|
||||
|
||||
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, user.Id, externalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, k := range externalUserInfos {
|
||||
eventUserInfos[i] = eventUserInfo{
|
||||
email: k.Email,
|
||||
name: k.Name,
|
||||
accountId: externalAccountId,
|
||||
}
|
||||
}
|
||||
|
||||
return eventUserInfos, nil
|
||||
}
|
||||
|
||||
func fillEventTargetInfo(eventUserInfo map[string]eventUserInfo, event *activity.Event) {
|
||||
userInfo, ok := eventUserInfo[event.TargetID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if event.Meta == nil {
|
||||
event.Meta = make(map[string]any)
|
||||
}
|
||||
|
||||
event.Meta["email"] = userInfo.email
|
||||
event.Meta["username"] = userInfo.name
|
||||
}
|
||||
|
||||
func fillEventInitiatorInfo(eventUserInfo map[string]eventUserInfo, event *activity.Event) bool {
|
||||
userInfo, ok := eventUserInfo[event.InitiatorID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if event.InitiatorEmail == "" {
|
||||
event.InitiatorEmail = userInfo.email
|
||||
}
|
||||
|
||||
if event.InitiatorName == "" {
|
||||
event.InitiatorName = userInfo.name
|
||||
}
|
||||
|
||||
if event.AccountID != userInfo.accountId {
|
||||
if event.Meta == nil {
|
||||
event.Meta = make(map[string]any)
|
||||
}
|
||||
|
||||
event.Meta["external"] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -10,13 +10,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
type GroupLinkError struct {
|
||||
@@ -35,8 +35,8 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -83,8 +83,8 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -215,8 +215,8 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -498,6 +498,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"network router", linkedRouter.ID}
|
||||
}
|
||||
|
||||
return checkGroupLinkedToSettings(ctx, transaction, group)
|
||||
}
|
||||
|
||||
@@ -613,6 +617,22 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
|
||||
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, router := range routers {
|
||||
if slices.Contains(router.PeerGroups, groupID) {
|
||||
return true, router
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
@@ -637,6 +657,9 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
|
||||
@@ -12,6 +12,13 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -414,6 +421,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupD",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupE",
|
||||
Name: "GroupE",
|
||||
Peers: []string{peer2.ID},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -673,4 +685,51 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to network router should update account peers and send peer update
|
||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
|
||||
ID: "network_test",
|
||||
AccountID: account.Id,
|
||||
Name: "network_test",
|
||||
Description: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = routersManager.CreateRouter(context.Background(), userID, &routerTypes.NetworkRouter{
|
||||
ID: "router_test",
|
||||
NetworkID: network.ID,
|
||||
AccountID: account.Id,
|
||||
PeerGroups: []string{"groupE"},
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupE",
|
||||
Name: "GroupE",
|
||||
Peers: []string{peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -40,7 +41,7 @@ type GRPCServer struct {
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *Config
|
||||
config *types.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
@@ -51,7 +52,7 @@ type GRPCServer struct {
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
config *types.Config,
|
||||
accountManager account.Manager,
|
||||
settingsManager settings.Manager,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
@@ -530,24 +531,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
||||
func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case UDP:
|
||||
case types.UDP:
|
||||
return proto.HostConfig_UDP
|
||||
case DTLS:
|
||||
case types.DTLS:
|
||||
return proto.HostConfig_DTLS
|
||||
case HTTP:
|
||||
case types.HTTP:
|
||||
return proto.HostConfig_HTTP
|
||||
case HTTPS:
|
||||
case types.HTTPS:
|
||||
return proto.HostConfig_HTTPS
|
||||
case TCP:
|
||||
case types.TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -610,8 +611,6 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token,
|
||||
Relay: relayCfg,
|
||||
}
|
||||
|
||||
integrationsConfig.ExtendNetBirdConfig(nbConfig, extraSettings)
|
||||
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
@@ -626,10 +625,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dns
|
||||
}
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
|
||||
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
@@ -638,6 +636,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
@@ -754,7 +756,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) {
|
||||
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) {
|
||||
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -47,66 +46,15 @@ func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
events := make([]*api.Event, len(accountEvents))
|
||||
for i, e := range accountEvents {
|
||||
events[i] = toEventResponse(e)
|
||||
}
|
||||
|
||||
err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, events)
|
||||
}
|
||||
|
||||
func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
|
||||
// build email, name maps based on users
|
||||
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
emails := make(map[string]string)
|
||||
names := make(map[string]string)
|
||||
for _, ui := range userInfos {
|
||||
emails[ui.ID] = ui.Email
|
||||
names[ui.ID] = ui.Name
|
||||
}
|
||||
|
||||
var ok bool
|
||||
for _, event := range events {
|
||||
// fill initiator
|
||||
if event.InitiatorEmail == "" {
|
||||
event.InitiatorEmail, ok = emails[event.InitiatorId]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
|
||||
}
|
||||
}
|
||||
|
||||
if event.InitiatorName == "" {
|
||||
// here to allowed to be empty because in the first release we did not store the name
|
||||
event.InitiatorName = names[event.InitiatorId]
|
||||
}
|
||||
|
||||
// fill target meta
|
||||
email, ok := emails[event.TargetId]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
event.Meta["email"] = email
|
||||
|
||||
username, ok := names[event.TargetId]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
event.Meta["username"] = username
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toEventResponse(event *activity.Event) *api.Event {
|
||||
meta := make(map[string]string)
|
||||
if event.Meta != nil {
|
||||
|
||||
@@ -250,7 +250,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -258,7 +258,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// If the user is regular user and does not own the peer
|
||||
// with the given peerID return an empty list
|
||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && !userAuth.IsChild {
|
||||
peer, ok := account.Peers[peerID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||
|
||||
@@ -122,6 +122,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
switch id {
|
||||
case adminUser:
|
||||
return account.Users[adminUser], nil
|
||||
case regularUser:
|
||||
return account.Users[regularUser], nil
|
||||
case serviceUser:
|
||||
return account.Users[serviceUser], nil
|
||||
default:
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
},
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
|
||||
@@ -301,7 +301,7 @@ func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -142,6 +142,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if user.AccountID != impersonate[0] {
|
||||
return r, fmt.Errorf("token is not valid for this account")
|
||||
}
|
||||
}
|
||||
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return r, err
|
||||
|
||||
@@ -239,16 +239,9 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid PAT Token ignores child",
|
||||
name: "Valid PAT Token returns forbidden for child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
IsPAT: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid JWT Token",
|
||||
|
||||
@@ -16,19 +16,18 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -124,8 +123,9 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
validatorMock := server.MocIntegratedValidator{}
|
||||
proxyController := integrations.NewController(store)
|
||||
userManager := users.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}))
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager)
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManagerMock)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
@@ -143,7 +143,6 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
resourcesManagerMock := resources.NewManagerMock()
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
|
||||
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManagerMock, peersManager, settingsManager)
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -93,21 +94,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
|
||||
|
||||
func Test_SyncProtocol(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
|
||||
Stuns: []*Host{{
|
||||
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
|
||||
Stuns: []*types.Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.netbird.io:3468",
|
||||
}},
|
||||
TURNConfig: &TURNConfig{
|
||||
TURNConfig: &types.TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{},
|
||||
Secret: "whatever",
|
||||
Turns: []*Host{{
|
||||
Turns: []*types.Host{{
|
||||
Proto: "udp",
|
||||
URI: "turn:stun.netbird.io:3468",
|
||||
}},
|
||||
},
|
||||
Signal: &Host{
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: "signal.netbird.io:10000",
|
||||
},
|
||||
@@ -330,7 +331,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputFlow *DeviceAuthorizationFlow
|
||||
inputFlow *types.DeviceAuthorizationFlow
|
||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
||||
expectedErrFunc require.ErrorAssertionFunc
|
||||
expectedErrMSG string
|
||||
@@ -345,9 +346,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Testing Invalid Device Flow Provider Config",
|
||||
inputFlow: &DeviceAuthorizationFlow{
|
||||
inputFlow: &types.DeviceAuthorizationFlow{
|
||||
Provider: "NoNe",
|
||||
ProviderConfig: ProviderConfig{
|
||||
ProviderConfig: types.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
@@ -356,9 +357,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Testing Full Device Flow Config",
|
||||
inputFlow: &DeviceAuthorizationFlow{
|
||||
inputFlow: &types.DeviceAuthorizationFlow{
|
||||
Provider: "hosted",
|
||||
ProviderConfig: ProviderConfig{
|
||||
ProviderConfig: types.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
@@ -379,7 +380,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mgmtServer := &GRPCServer{
|
||||
wgKey: testingServerKey,
|
||||
config: &Config{
|
||||
config: &types.Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
@@ -410,7 +411,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
|
||||
func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
@@ -431,6 +432,8 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
@@ -441,7 +444,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
|
||||
Return(&types.Settings{}, nil)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
|
||||
if err != nil {
|
||||
cleanup()
|
||||
@@ -506,21 +509,21 @@ func testSyncStatusRace(t *testing.T) {
|
||||
t.Skip()
|
||||
dir := t.TempDir()
|
||||
|
||||
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
|
||||
Stuns: []*Host{{
|
||||
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
|
||||
Stuns: []*types.Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.netbird.io:3468",
|
||||
}},
|
||||
TURNConfig: &TURNConfig{
|
||||
TURNConfig: &types.TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{},
|
||||
Secret: "whatever",
|
||||
Turns: []*Host{{
|
||||
Turns: []*types.Host{{
|
||||
Proto: "udp",
|
||||
URI: "turn:stun.netbird.io:3468",
|
||||
}},
|
||||
},
|
||||
Signal: &Host{
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: "signal.netbird.io:10000",
|
||||
},
|
||||
@@ -678,21 +681,21 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
|
||||
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
|
||||
Stuns: []*Host{{
|
||||
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
|
||||
Stuns: []*types.Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.netbird.io:3468",
|
||||
}},
|
||||
TURNConfig: &TURNConfig{
|
||||
TURNConfig: &types.TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{},
|
||||
Secret: "whatever",
|
||||
Turns: []*Host{{
|
||||
Turns: []*types.Host{{
|
||||
Proto: "udp",
|
||||
URI: "turn:stun.netbird.io:3468",
|
||||
}},
|
||||
},
|
||||
Signal: &Host{
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: "signal.netbird.io:10000",
|
||||
},
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -58,7 +59,7 @@ func setupTest(t *testing.T) *testSuite {
|
||||
t.Fatalf("failed to create temp directory: %v", err)
|
||||
}
|
||||
|
||||
config := &server.Config{}
|
||||
config := &types.Config{}
|
||||
_, err = util.ReadJson("testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read management.json: %v", err)
|
||||
@@ -156,7 +157,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
|
||||
|
||||
func startServer(
|
||||
t *testing.T,
|
||||
config *server.Config,
|
||||
config *types.Config,
|
||||
dataDir string,
|
||||
testFile string,
|
||||
) (*grpc.Server, net.Listener) {
|
||||
@@ -194,6 +195,7 @@ func startServer(
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
accountManager, err := server.BuildManager(
|
||||
context.Background(),
|
||||
str,
|
||||
@@ -208,6 +210,7 @@ func startServer(
|
||||
metrics,
|
||||
port_forwarding.NewControllerMock(),
|
||||
settingsMockManager,
|
||||
permissionsManagerMock,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating an account manager: %v", err)
|
||||
@@ -300,6 +303,10 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
|
||||
Protocol: mgmtProto.HostConfig_UDP,
|
||||
}
|
||||
|
||||
expectedRelayHost := &mgmtProto.RelayConfig{
|
||||
Urls: []string{"rel://test.com:3535"},
|
||||
}
|
||||
|
||||
assert.NotNil(t, resp.NetbirdConfig)
|
||||
assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig)
|
||||
assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig)
|
||||
@@ -307,6 +314,8 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
|
||||
actualTURN := resp.NetbirdConfig.Turns[0]
|
||||
assert.Greater(t, len(actualTURN.User), 0)
|
||||
assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost)
|
||||
assert.Equal(t, len(resp.NetbirdConfig.Relay.Urls), 1)
|
||||
assert.Equal(t, resp.NetbirdConfig.Relay.Urls, expectedRelayHost.Urls)
|
||||
assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
@@ -49,7 +48,7 @@ type properties map[string]interface{}
|
||||
// DataSource metric data source
|
||||
type DataSource interface {
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetStoreEngine() store.Engine
|
||||
GetStoreEngine() types.Engine
|
||||
}
|
||||
|
||||
// ConnManager peer connection manager that holds state for current active connections
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -205,8 +204,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
}
|
||||
|
||||
// GetStoreEngine returns FileStoreEngine
|
||||
func (mockDatasource) GetStoreEngine() store.Engine {
|
||||
return store.FileStoreEngine
|
||||
func (mockDatasource) GetStoreEngine() types.Engine {
|
||||
return types.FileStoreEngine
|
||||
}
|
||||
|
||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||
@@ -304,7 +303,7 @@ func TestGenerateProperties(t *testing.T) {
|
||||
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
|
||||
}
|
||||
|
||||
if properties["store_engine"] != store.FileStoreEngine {
|
||||
if properties["store_engine"] != types.FileStoreEngine {
|
||||
t.Errorf("expected JsonFile, got %s", properties["store_engine"])
|
||||
}
|
||||
|
||||
|
||||
@@ -112,6 +112,9 @@ type MockAccountManager struct {
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
GetStoreFunc func() store.Store
|
||||
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
@@ -847,3 +850,24 @@ func (am *MockAccountManager) GetStore() store.Store {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
|
||||
if am.CreateAccountByPrivateDomainFunc != nil {
|
||||
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
if am.UpdateToPrimaryAccountFunc != nil {
|
||||
return am.UpdateToPrimaryAccountFunc(ctx, accountId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {
|
||||
if am.GetOwnerInfoFunc != nil {
|
||||
return am.GetOwnerInfoFunc(ctx, accountId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -46,8 +46,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newNSGroup := &nbdns.NameServerGroup{
|
||||
@@ -108,8 +108,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
@@ -159,8 +159,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
@@ -203,8 +203,8 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -774,11 +775,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
|
||||
@@ -37,8 +37,8 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
@@ -188,8 +188,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
@@ -321,8 +321,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -621,7 +621,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user last login: %w", err)
|
||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||
}
|
||||
} else {
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||
@@ -1054,7 +1054,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
|
||||
|
||||
err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin())
|
||||
if err != nil {
|
||||
return err
|
||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
@@ -1099,8 +1099,8 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
@@ -1213,7 +1213,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
update := toSyncResponse(ctx, &Config{}, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}(peer)
|
||||
}
|
||||
@@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
update := toSyncResponse(ctx, &Config{}, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}
|
||||
|
||||
|
||||
@@ -20,9 +20,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -723,7 +724,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||
b.Helper()
|
||||
|
||||
manager, err := createManager(b)
|
||||
@@ -998,6 +999,53 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}{
|
||||
{"Small", 50, 1},
|
||||
{"Medium", 500, 1},
|
||||
{"Large", 1000, 1},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
manager.UpdateAccountPeers(ctx, account.Id)
|
||||
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
assert.Nil(t, update.Update.NetbirdConfig)
|
||||
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
|
||||
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToSyncResponse(t *testing.T) {
|
||||
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
@@ -1008,16 +1056,16 @@ func TestToSyncResponse(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
config := &types.Config{
|
||||
Signal: &types.Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}},
|
||||
TURNConfig: &types.TURNConfig{
|
||||
Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
@@ -1217,7 +1265,8 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1285,7 +1334,8 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1356,7 +1406,8 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
@@ -5,10 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
|
||||
type Module string
|
||||
@@ -17,6 +16,8 @@ const (
|
||||
Networks Module = "networks"
|
||||
Peers Module = "peers"
|
||||
Groups Module = "groups"
|
||||
Settings Module = "settings"
|
||||
Accounts Module = "accounts"
|
||||
)
|
||||
|
||||
type Operation string
|
||||
@@ -28,42 +29,50 @@ const (
|
||||
|
||||
type Manager interface {
|
||||
ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error)
|
||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
userManager users.Manager
|
||||
settingsManager settings.Manager
|
||||
store store.Store
|
||||
}
|
||||
|
||||
type managerMock struct {
|
||||
}
|
||||
|
||||
func NewManager(userManager users.Manager, settingsManager settings.Manager) Manager {
|
||||
func NewManager(store store.Store) Manager {
|
||||
return &managerImpl{
|
||||
userManager: userManager,
|
||||
settingsManager: settingsManager,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
|
||||
user, err := m.userManager.GetUser(ctx, userID)
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return false, errors.New("user not found")
|
||||
return false, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return false, errors.New("user does not belong to account")
|
||||
if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch module {
|
||||
case Accounts:
|
||||
if operation == Write && user.Role != types.UserRoleOwner {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
default:
|
||||
}
|
||||
|
||||
switch user.Role {
|
||||
case types.UserRoleAdmin, types.UserRoleOwner:
|
||||
return true, nil
|
||||
case types.UserRoleUser:
|
||||
return m.validateRegularUserPermissions(ctx, accountID, userID, module, operation)
|
||||
return m.validateRegularUserPermissions(ctx, accountID, module, operation)
|
||||
case types.UserRoleBillingAdmin:
|
||||
return false, nil
|
||||
default:
|
||||
@@ -71,8 +80,8 @@ func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
|
||||
settings, err := m.settingsManager.GetSettings(ctx, accountID, activity.SystemInitiator)
|
||||
func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID string, module Module, operation Operation) (bool, error) {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get settings: %w", err)
|
||||
}
|
||||
@@ -91,13 +100,30 @@ func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accoun
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &managerMock{}
|
||||
}
|
||||
|
||||
func (m *managerMock) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
|
||||
if userID == "allowedUser" {
|
||||
switch userID {
|
||||
case "a23efe53-63fb-11ec-90d6-0242ac120003", "allowedUser", "testingUser", "account_creator", "serviceUserID", "test_user":
|
||||
return true, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *managerMock) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
||||
// @note managers explicitly checked this, so should the mock
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -43,8 +43,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -100,8 +100,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -148,8 +148,8 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
|
||||
@@ -22,8 +22,8 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
@@ -43,8 +43,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
@@ -99,8 +99,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
@@ -141,8 +141,8 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
|
||||
@@ -25,7 +25,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
@@ -119,16 +123,18 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerID); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(domains) > 0 && prefix.IsValid() {
|
||||
@@ -236,16 +242,18 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
@@ -310,6 +318,15 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -342,7 +359,11 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -1116,14 +1117,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
|
||||
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 3, "after peer deletion group should have 3 client routes")
|
||||
|
||||
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have only 2 routes")
|
||||
|
||||
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
|
||||
require.NoError(t, err)
|
||||
@@ -1134,7 +1135,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
|
||||
peer2RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer2ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
|
||||
assert.Len(t, peer2RoutesAfterAdd.Routes, 3, "HA route should have 3 client routes")
|
||||
|
||||
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
@@ -1259,6 +1260,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
@@ -1281,7 +1283,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
AnyTimes().
|
||||
Return(&types.ExtraSettings{}, nil)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||
@@ -1492,7 +1494,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
|
||||
{
|
||||
ID: routeGroupHA1,
|
||||
Name: routeGroupHA1,
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, // we have one non Linux peer, see peer3
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: routeGroupHA2,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -25,13 +26,15 @@ type managerImpl struct {
|
||||
store store.Store
|
||||
extraSettingsManager extra_settings.Manager
|
||||
userManager users.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager) Manager {
|
||||
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
extraSettingsManager: extraSettingsManager,
|
||||
userManager: userManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,13 +44,12 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
|
||||
|
||||
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
||||
if userID != activity.SystemInitiator {
|
||||
user, err := m.userManager.GetUser(ctx, userID)
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -61,8 +61,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -118,8 +118,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -180,8 +180,8 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -198,8 +198,8 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
@@ -226,8 +226,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
|
||||
@@ -183,7 +183,7 @@ func NewPermissionDeniedError() error {
|
||||
}
|
||||
|
||||
func NewPermissionValidationError(err error) error {
|
||||
return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err)
|
||||
return Errorf(PermissionDenied, "failed to validate user permissions: %s", err)
|
||||
}
|
||||
|
||||
func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {
|
||||
|
||||
@@ -260,6 +260,6 @@ func (s *FileStore) Close(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// GetStoreEngine returns FileStoreEngine
|
||||
func (s *FileStore) GetStoreEngine() Engine {
|
||||
return FileStoreEngine
|
||||
func (s *FileStore) GetStoreEngine() types.Engine {
|
||||
return types.FileStoreEngine
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ type SqlStore struct {
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
storeEngine Engine
|
||||
storeEngine types.Engine
|
||||
}
|
||||
|
||||
type installation struct {
|
||||
@@ -66,7 +66,7 @@ type installation struct {
|
||||
type migrationFunc func(*gorm.DB) error
|
||||
|
||||
// NewSqlStore creates a new SqlStore instance.
|
||||
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
sql, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -77,7 +77,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
|
||||
if storeEngine == SqliteStoreEngine {
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
if err == nil {
|
||||
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
|
||||
}
|
||||
|
||||
func GetKeyQueryCondition(s *SqlStore) string {
|
||||
if s.storeEngine == MysqlStoreEngine {
|
||||
if s.storeEngine == types.MysqlStoreEngine {
|
||||
return mysqlKeyQueryCondition
|
||||
}
|
||||
return keyQueryCondition
|
||||
@@ -586,6 +586,19 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get account owner from the store")
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
|
||||
var groups []*types.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
|
||||
@@ -970,7 +983,7 @@ func (s *SqlStore) Close(_ context.Context) error {
|
||||
}
|
||||
|
||||
// GetStoreEngine returns underlying store engine
|
||||
func (s *SqlStore) GetStoreEngine() Engine {
|
||||
func (s *SqlStore) GetStoreEngine() types.Engine {
|
||||
return s.storeEngine
|
||||
}
|
||||
|
||||
@@ -988,7 +1001,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(ctx, db, SqliteStoreEngine, metrics)
|
||||
return NewSqlStore(ctx, db, types.SqliteStoreEngine, metrics)
|
||||
}
|
||||
|
||||
// NewPostgresqlStore creates a new Postgres store.
|
||||
@@ -998,7 +1011,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
|
||||
return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics)
|
||||
}
|
||||
|
||||
// NewMysqlStore creates a new MySQL store.
|
||||
@@ -1008,7 +1021,7 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics)
|
||||
return NewSqlStore(ctx, db, types.MysqlStoreEngine, metrics)
|
||||
}
|
||||
|
||||
func getGormConfig() *gorm.Config {
|
||||
@@ -1517,9 +1530,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
|
||||
|
||||
switch s.storeEngine {
|
||||
case PostgresStoreEngine:
|
||||
case types.PostgresStoreEngine:
|
||||
query = query.Order("json_array_length(peers::json) DESC")
|
||||
case MysqlStoreEngine:
|
||||
case types.MysqlStoreEngine:
|
||||
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
|
||||
default:
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
@@ -2194,3 +2207,17 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
|
||||
|
||||
return &peer, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
|
||||
var count int64
|
||||
result := s.db.Model(&types.Account{}).
|
||||
Where("domain = ? AND domain_category = ?",
|
||||
strings.ToLower(domain), types.PrivateCategory,
|
||||
).Count(&count)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to count accounts by private domain %s: %s", domain, result.Error)
|
||||
return 0, status.Errorf(status.Internal, "failed to count accounts by private domain")
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
@@ -297,7 +297,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
@@ -628,7 +628,7 @@ func TestMigrate(t *testing.T) {
|
||||
}
|
||||
|
||||
// TODO: figure out why this fails on postgres
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
@@ -737,7 +737,7 @@ func TestPostgresql_NewStore(t *testing.T) {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
@@ -752,7 +752,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
@@ -825,7 +825,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
@@ -901,7 +901,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
@@ -921,7 +921,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
@@ -935,7 +935,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
@@ -980,7 +980,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
return
|
||||
@@ -1022,7 +1022,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
@@ -1045,7 +1045,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
@@ -1070,7 +1070,7 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
@@ -1106,7 +1106,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
@@ -1211,7 +1211,7 @@ func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
@@ -1227,7 +1227,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -69,10 +69,12 @@ type Store interface {
|
||||
DeleteAccount(ctx context.Context, account *types.Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
||||
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
||||
GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error)
|
||||
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
@@ -164,7 +166,7 @@ type Store interface {
|
||||
Close(ctx context.Context) error
|
||||
// GetStoreEngine should return Engine of the current store implementation.
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() Engine
|
||||
GetStoreEngine() types.Engine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
|
||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
||||
@@ -187,44 +189,37 @@ type Store interface {
|
||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||
}
|
||||
|
||||
type Engine string
|
||||
|
||||
const (
|
||||
FileStoreEngine Engine = "jsonfile"
|
||||
SqliteStoreEngine Engine = "sqlite"
|
||||
PostgresStoreEngine Engine = "postgres"
|
||||
MysqlStoreEngine Engine = "mysql"
|
||||
|
||||
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
|
||||
)
|
||||
|
||||
var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}
|
||||
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
|
||||
|
||||
func getStoreEngineFromEnv() Engine {
|
||||
func getStoreEngineFromEnv() types.Engine {
|
||||
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
|
||||
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
value := Engine(strings.ToLower(kind))
|
||||
value := types.Engine(strings.ToLower(kind))
|
||||
if slices.Contains(supportedEngines, value) {
|
||||
return value
|
||||
}
|
||||
|
||||
return SqliteStoreEngine
|
||||
return types.SqliteStoreEngine
|
||||
}
|
||||
|
||||
// getStoreEngine determines the store engine to use.
|
||||
// If no engine is specified, it attempts to retrieve it from the environment.
|
||||
// If still not specified, it defaults to using SQLite.
|
||||
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
|
||||
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
|
||||
func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) types.Engine {
|
||||
if kind == "" {
|
||||
kind = getStoreEngineFromEnv()
|
||||
if kind == "" {
|
||||
kind = SqliteStoreEngine
|
||||
kind = types.SqliteStoreEngine
|
||||
|
||||
// Migrate if it is the first run with a JSON file existing and no SQLite file present
|
||||
jsonStoreFile := filepath.Join(dataDir, storeFileName)
|
||||
@@ -236,7 +231,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
|
||||
// Attempt to migrate from JSON store to SQLite
|
||||
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
|
||||
kind = FileStoreEngine
|
||||
kind = types.FileStoreEngine
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -246,7 +241,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
|
||||
}
|
||||
|
||||
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
|
||||
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||
func NewStore(ctx context.Context, kind types.Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||
kind = getStoreEngine(ctx, dataDir, kind)
|
||||
|
||||
if err := checkFileStoreEngine(kind, dataDir); err != nil {
|
||||
@@ -254,13 +249,13 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case SqliteStoreEngine:
|
||||
case types.SqliteStoreEngine:
|
||||
log.WithContext(ctx).Info("using SQLite store engine")
|
||||
return NewSqliteStore(ctx, dataDir, metrics)
|
||||
case PostgresStoreEngine:
|
||||
case types.PostgresStoreEngine:
|
||||
log.WithContext(ctx).Info("using Postgres store engine")
|
||||
return newPostgresStore(ctx, metrics)
|
||||
case MysqlStoreEngine:
|
||||
case types.MysqlStoreEngine:
|
||||
log.WithContext(ctx).Info("using MySQL store engine")
|
||||
return newMysqlStore(ctx, metrics)
|
||||
default:
|
||||
@@ -268,12 +263,12 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileStoreEngine(kind Engine, dataDir string) error {
|
||||
if kind == FileStoreEngine {
|
||||
func checkFileStoreEngine(kind types.Engine, dataDir string) error {
|
||||
if kind == types.FileStoreEngine {
|
||||
storeFile := filepath.Join(dataDir, storeFileName)
|
||||
if util.FileExists(storeFile) {
|
||||
return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+
|
||||
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine)
|
||||
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", types.FileStoreEngine)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -326,7 +321,7 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
||||
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
|
||||
kind := getStoreEngineFromEnv()
|
||||
if kind == "" {
|
||||
kind = SqliteStoreEngine
|
||||
kind = types.SqliteStoreEngine
|
||||
}
|
||||
|
||||
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
|
||||
@@ -348,7 +343,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
|
||||
}
|
||||
}
|
||||
|
||||
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
|
||||
store, err := NewSqlStore(ctx, db, types.SqliteStoreEngine, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
|
||||
}
|
||||
@@ -394,13 +389,13 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
|
||||
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) (Store, func(), error) {
|
||||
var cleanup func()
|
||||
var err error
|
||||
switch kind {
|
||||
case PostgresStoreEngine:
|
||||
case types.PostgresStoreEngine:
|
||||
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
|
||||
case MysqlStoreEngine:
|
||||
case types.MysqlStoreEngine:
|
||||
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
|
||||
default:
|
||||
cleanup = func() {
|
||||
@@ -419,7 +414,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store
|
||||
return store, closeConnection, nil
|
||||
}
|
||||
|
||||
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
|
||||
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
|
||||
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
|
||||
var err error
|
||||
_, err = testutil.CreatePostgresTestContainer()
|
||||
@@ -451,7 +446,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (
|
||||
return store, cleanup, nil
|
||||
}
|
||||
|
||||
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
|
||||
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
|
||||
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
|
||||
var err error
|
||||
_, err = testutil.CreateMysqlTestContainer()
|
||||
@@ -483,7 +478,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*Sq
|
||||
return store, cleanup, nil
|
||||
}
|
||||
|
||||
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
|
||||
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
|
||||
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
|
||||
|
||||
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
|
||||
@@ -493,9 +488,9 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err
|
||||
var err error
|
||||
cleanup := func() {
|
||||
switch engine {
|
||||
case PostgresStoreEngine:
|
||||
case types.PostgresStoreEngine:
|
||||
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
|
||||
case MysqlStoreEngine:
|
||||
case types.MysqlStoreEngine:
|
||||
// err = killMySQLConnections(dsn, dbName)
|
||||
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
|
||||
}
|
||||
|
||||
5
management/server/testdata/management.json
vendored
5
management/server/testdata/management.json
vendored
@@ -20,6 +20,11 @@
|
||||
"Secret": "c29tZV9wYXNzd29yZA==",
|
||||
"TimeBasedCredentials": true
|
||||
},
|
||||
"Relay":{
|
||||
"Addresses":["rel://test.com:3535"],
|
||||
"CredentialsTTL":"2h",
|
||||
"Secret":"netbird"
|
||||
},
|
||||
"Signal": {
|
||||
"Proto": "http",
|
||||
"URI": "signal.netbird.io:10000",
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||
|
||||
@@ -32,8 +33,8 @@ type SecretsManager interface {
|
||||
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||
type TimeBasedAuthSecretsManager struct {
|
||||
mux sync.Mutex
|
||||
turnCfg *TURNConfig
|
||||
relayCfg *Relay
|
||||
turnCfg *types.TURNConfig
|
||||
relayCfg *types.Relay
|
||||
turnHmacToken *auth.TimedHMAC
|
||||
relayHmacToken *authv2.Generator
|
||||
updateManager *PeersUpdateManager
|
||||
@@ -44,7 +45,7 @@ type TimeBasedAuthSecretsManager struct {
|
||||
|
||||
type Token auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
|
||||
mgr := &TimeBasedAuthSecretsManager{
|
||||
updateManager: updateManager,
|
||||
turnCfg: turnCfg,
|
||||
@@ -221,7 +222,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
||||
}
|
||||
}
|
||||
|
||||
m.extendNetbirdConfig(ctx, accountID, update)
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
@@ -245,17 +246,18 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
||||
},
|
||||
}
|
||||
|
||||
m.extendNetbirdConfig(ctx, accountID, update)
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, accountID string, update *proto.SyncResponse) {
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
|
||||
}
|
||||
|
||||
integrationsConfig.ExtendNetBirdConfig(update.NetbirdConfig, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
|
||||
update.NetbirdConfig = extendedConfig
|
||||
}
|
||||
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var TurnTestHost = &Host{
|
||||
Proto: UDP,
|
||||
var TurnTestHost = &types.Host{
|
||||
Proto: types.UDP,
|
||||
URI: "turn:turn.netbird.io:77777",
|
||||
Username: "username",
|
||||
Password: "",
|
||||
@@ -31,7 +31,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
secret := "some_secret"
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
|
||||
rc := &Relay{
|
||||
rc := &types.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
@@ -41,10 +41,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*Host{TurnTestHost},
|
||||
Turns: []*types.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager)
|
||||
|
||||
@@ -81,7 +81,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
peer := "some_peer"
|
||||
updateChannel := peersManager.CreateChannel(context.Background(), peer)
|
||||
|
||||
rc := &Relay{
|
||||
rc := &types.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
@@ -92,10 +92,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*Host{TurnTestHost},
|
||||
Turns: []*types.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager)
|
||||
|
||||
@@ -184,7 +184,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
|
||||
rc := &Relay{
|
||||
rc := &types.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
@@ -194,10 +194,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*Host{TurnTestHost},
|
||||
Turns: []*types.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager)
|
||||
|
||||
|
||||
@@ -149,11 +149,6 @@ func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enab
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
// currently we support only linux routing peers
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
seenRoute := make(map[route.ID]struct{})
|
||||
|
||||
takeRoute := func(r *route.Route, id string) {
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package server
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -30,6 +29,8 @@ const (
|
||||
DefaultDeviceAuthFlowScope string = "openid"
|
||||
)
|
||||
|
||||
var MgmtConfigPath string
|
||||
|
||||
// Config of the Management service
|
||||
type Config struct {
|
||||
Stuns []*Host
|
||||
@@ -76,6 +77,7 @@ type TURNConfig struct {
|
||||
Turns []*Host
|
||||
}
|
||||
|
||||
// Relay configuration type
|
||||
type Relay struct {
|
||||
Addresses []string
|
||||
CredentialsTTL util.Duration
|
||||
@@ -156,7 +158,7 @@ type ProviderConfig struct {
|
||||
|
||||
// StoreConfig contains Store configuration
|
||||
type StoreConfig struct {
|
||||
Engine store.Engine
|
||||
Engine Engine
|
||||
}
|
||||
|
||||
// ReverseProxy contains reverse proxy configuration in front of management.
|
||||
10
management/server/types/store.go
Normal file
10
management/server/types/store.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package types
|
||||
|
||||
type Engine string
|
||||
|
||||
const (
|
||||
PostgresStoreEngine Engine = "postgres"
|
||||
FileStoreEngine Engine = "jsonfile"
|
||||
SqliteStoreEngine Engine = "sqlite"
|
||||
MysqlStoreEngine Engine = "mysql"
|
||||
)
|
||||
@@ -30,8 +30,8 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() {
|
||||
@@ -93,8 +93,8 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inviterID := userID
|
||||
@@ -142,12 +142,21 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
|
||||
// createNewIdpUser validates the invite and creates a new user in the IdP
|
||||
func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
|
||||
inviter, err := am.GetUserByID(ctx, inviterID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get inviter user: %w", err)
|
||||
}
|
||||
|
||||
// inviterUser is the one who is inviting the new user
|
||||
inviterUser, err := am.lookupUserInCache(ctx, inviterID, accountID)
|
||||
inviterUser, err := am.lookupUserInCache(ctx, inviterID, inviter.AccountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID)
|
||||
}
|
||||
|
||||
if inviterUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "inviter user with ID %s is empty", inviterID)
|
||||
}
|
||||
|
||||
// check if the user is already registered with this email => reject
|
||||
user, err := am.lookupUserInCacheByEmail(ctx, invite.Email, accountID)
|
||||
if err != nil {
|
||||
@@ -188,7 +197,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
|
||||
|
||||
err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||
}
|
||||
|
||||
if newLogin {
|
||||
@@ -228,8 +237,8 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
return err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() {
|
||||
@@ -290,8 +299,8 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
return err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check if the user is already registered with this ID
|
||||
@@ -338,8 +347,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
@@ -376,8 +385,8 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
||||
return err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
|
||||
@@ -411,8 +420,8 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
|
||||
@@ -429,8 +438,8 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
|
||||
@@ -476,8 +485,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() {
|
||||
@@ -511,7 +520,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||
ctx, transaction, groupsMap, initiatorUser, update, addIfNotExists, settings,
|
||||
ctx, transaction, groupsMap, accountID, initiatorUser, update, addIfNotExists, settings,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process user update: %w", err)
|
||||
@@ -597,13 +606,13 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group,
|
||||
initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
||||
accountID string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
||||
|
||||
if update == nil {
|
||||
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, update, addIfNotExists)
|
||||
oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, err
|
||||
}
|
||||
@@ -614,7 +623,6 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
|
||||
// only auto groups, revoked status, and integration reference can be updated for now
|
||||
updatedUser := oldUser.Copy()
|
||||
updatedUser.AccountID = initiatorUser.AccountID
|
||||
updatedUser.Role = update.Role
|
||||
updatedUser.Blocked = update.Blocked
|
||||
updatedUser.AutoGroups = update.AutoGroups
|
||||
@@ -657,17 +665,23 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
}
|
||||
|
||||
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
||||
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) {
|
||||
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
|
||||
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
if !addIfNotExists {
|
||||
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
|
||||
}
|
||||
update.AccountID = accountID
|
||||
return update, nil // use all fields from update if addIfNotExists is true
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if existingUser.AccountID != accountID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch")
|
||||
}
|
||||
|
||||
return existingUser, nil
|
||||
}
|
||||
|
||||
@@ -705,6 +719,7 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.Us
|
||||
|
||||
// validateUserUpdate validates the update operation for a user.
|
||||
func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error {
|
||||
// @todo double check these
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
@@ -790,8 +805,8 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if initiatorUser.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.BuildUserInfosForAccount(ctx, accountID, initiatorUserID, accountUsers)
|
||||
@@ -967,6 +982,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
@@ -1081,6 +1100,25 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
return updateAccountPeers, nil
|
||||
}
|
||||
|
||||
// GetOwnerInfo retrieves the owner information for a given account ID.
|
||||
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
|
||||
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if owner == nil {
|
||||
return nil, status.Errorf(status.NotFound, "owner not found")
|
||||
}
|
||||
|
||||
userInfo, err := am.getUserInfo(ctx, owner, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||
func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
|
||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -59,9 +59,11 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: s,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: s,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
@@ -107,9 +109,11 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
|
||||
@@ -133,9 +137,11 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
|
||||
@@ -160,9 +166,11 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn)
|
||||
@@ -183,9 +191,11 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
|
||||
@@ -214,9 +224,11 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
|
||||
@@ -255,9 +267,11 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
|
||||
@@ -296,9 +310,11 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID)
|
||||
@@ -390,9 +406,11 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"})
|
||||
@@ -435,9 +453,11 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
|
||||
@@ -481,9 +501,11 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
|
||||
@@ -510,10 +532,12 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
|
||||
@@ -616,9 +640,11 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID)
|
||||
@@ -652,9 +678,11 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID)
|
||||
@@ -704,10 +732,12 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
integratedPeerValidator: MocIntegratedValidator{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -812,10 +842,12 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
integratedPeerValidator: MocIntegratedValidator{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -921,9 +953,11 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
claims := nbcontext.UserAuth{
|
||||
@@ -957,9 +991,11 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
users, err := am.ListUsers(context.Background(), mockAccountID)
|
||||
@@ -1044,9 +1080,11 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
users, err := am.ListUsers(context.Background(), mockAccountID)
|
||||
@@ -1087,11 +1125,13 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
|
||||
@@ -1148,9 +1188,11 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
|
||||
@@ -1180,9 +1222,11 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsMananagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsMananagerMock,
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID)
|
||||
@@ -1525,3 +1569,41 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "")
|
||||
targetId := "user2"
|
||||
account1.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
AccountID: account1.Id,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account2))
|
||||
|
||||
permissionsManagerMock := permissions.NewManagerMock()
|
||||
am := DefaultAccountManager{
|
||||
Store: s,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
idpManager: nil,
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
permissionsManager: permissionsManagerMock,
|
||||
}
|
||||
|
||||
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
|
||||
assert.Error(t, err, "update user to another account should fail")
|
||||
|
||||
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account1.Users[targetId].Id, user.Id)
|
||||
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)
|
||||
assert.Equal(t, account1.Users[targetId].AutoGroups, user.AutoGroups)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user