Compare commits

...

16 Commits

Author SHA1 Message Date
Pascal Fischer
e4fea16c9f fix test expectation 2025-04-02 17:10:14 +02:00
Pascal Fischer
b3ab142d2e allow admins to fetch users 2025-04-02 16:20:34 +02:00
Pascal Fischer
de457788ba return error when trying to use accountID path variable with PAT 2025-04-02 16:03:33 +02:00
Viktor Liu
09243a0fe0 [management] Remove remaining backend linux router limitation (#3589) 2025-04-01 21:29:57 +02:00
Maycon Santos
3658215747 [client] Force new user login on PKCE auth in CLI (#3604)
With this change, browser session won't be considered for cli authentication and credentials will be requested
2025-04-01 10:29:29 +02:00
Viktor Liu
48ffec95dd Improve local ip lookup (#3551)
- lower memory footprint in most cases
- increase accuracy
2025-03-31 10:05:57 +02:00
Pedro Maia Costa
cbec7bda80 [management] permission manager validate account access (#3444) 2025-03-30 17:08:22 +02:00
Zoltan Papp
21464ac770 [client] Fix close WireGuard watcher (#3598)
This PR fixes issues with closing the WireGuard watcher by adjusting its asynchronous invocation and synchronization.

Update tests in wg_watcher_test.go to launch the watcher in a goroutine and add a delay for timing.
Modify wg_watcher.go to run the periodic handshake check synchronously by removing the waitGroup and goroutine.
Enhance conn.go to wait on the watcher wait group during connection close and add a note for potential further synchronization
2025-03-28 20:12:31 +01:00
Zoltan Papp
ed5647028a [client] Prevent calling the onDisconnected callback in incorrect state (#3582)
Prevent calling the onDisconnected callback if the ICE connection has never been established

If call onDisconnected without onConnected then overwrite the relayed status in the conn priority variable.
2025-03-28 18:08:26 +01:00
Viktor Liu
29a6e5be71 [client] Stop flow grpc receiver properly (#3596) 2025-03-28 16:08:31 +01:00
Viktor Liu
6124e3b937 [client] Disable systemd-resolved default route explicitly on match domains only (#3584) 2025-03-28 11:14:32 +01:00
Maycon Santos
50f5cc48cd [management] Fix extended config when nil (#3593)
* Fix extended config when nil

* update integrations
2025-03-27 23:07:10 +01:00
Viktor Liu
101cce27f2 [client] Ensure status recorder is always initialized (#3588)
* Ensure status recorder is always initialized

* Add test

* Add subscribe test
2025-03-27 22:48:11 +01:00
Maycon Santos
a4f04f5570 [management] fix extend call and move config to types (#3575)
This PR fixes configuration inconsistencies and updates the store engine type usage throughout the management code. Key changes include:
- Replacing outdated server.Config references with types.Config and updating related flag variables (e.g. types.MgmtConfigPath).
- Converting engine constants (SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine) to use types.Engine for consistent type–safety.
- Adjusting various test and migration code paths to correctly reference the new configuration and engine types.
2025-03-27 13:04:50 +01:00
hakansa
fceb3ca392 [client] fix route handling for local peer state (#3586) 2025-03-27 19:31:04 +08:00
Bethuel Mmbaga
34d86c5ab8 [management] Sync account peers on network router group changes (#3573)
- Updates account peers when a group linked to a network router is modified
- Prevents group deletion if it's still being used by any network router
2025-03-27 12:19:22 +01:00
76 changed files with 1858 additions and 617 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@@ -12,9 +12,11 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
@@ -32,7 +34,7 @@ import (
func startTestingServices(t *testing.T) string {
t.Helper()
config := &mgmt.Config{}
config := &types.Config{}
_, err := util.ReadJson("../testdata/management.json", config)
if err != nil {
t.Fatal(err)
@@ -67,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis
}
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
@@ -90,13 +92,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,7 +1,6 @@
package conntrack
import (
"context"
"net/netip"
"testing"
@@ -12,7 +11,7 @@ import (
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {

View File

@@ -14,8 +14,13 @@ import (
type localIPManager struct {
mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
// fixed-size high array for upper byte of a IPv4 address
ipv4Bitmap [256]*ipv4LowBitmap
}
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
}
func newLocalIPManager() *localIPManager {
@@ -27,35 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
if ipv4 == nil {
return
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
ipStr := ip.String()
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
ipStr := ipv4.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -73,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
continue
}
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
@@ -86,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}
}()
var newIPv4Bitmap [1 << 16]uint32
var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}
if iface != nil {
@@ -120,12 +149,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() {
return false
}
m.mu.RLock()
defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
}
return false
return m.checkBitmapBit(ip.AsSlice())
}

View File

@@ -77,6 +77,18 @@ func TestLocalIPManager(t *testing.T) {
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
@@ -192,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
// Setup bitmap
bitmapManager := newLocalIPManager()
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
@@ -248,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm := newLocalIPManager()
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -257,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
})
b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm := newLocalIPManager()
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))

View File

@@ -1,7 +1,6 @@
package uspfilter
import (
"context"
"fmt"
"net"
"net/netip"
@@ -24,7 +23,7 @@ import (
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error

View File

@@ -1,7 +1,6 @@
package acl
import (
"context"
"net"
"testing"
@@ -15,7 +14,7 @@ import (
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{

View File

@@ -99,6 +99,7 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
oauth2.SetAuthURLParam("prompt", "login"),
)
return AuthFlowInfo{

View File

@@ -31,7 +31,7 @@ import (
"github.com/netbirdio/netbird/formatter"
)
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type mocWGIface struct {
filter device.PacketFilter

View File

@@ -38,7 +38,6 @@ const (
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
}
@@ -124,18 +123,19 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
}
if config.RouteAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
if err != nil {
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
return fmt.Errorf("set link as default dns router: %w", err)
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: nbdns.RootZone,
MatchOnly: true,
})
s.routingAll = true
} else if s.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} else {
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
return fmt.Errorf("remove link as default dns router: %w", err)
}
}
state := &ShutdownState{

View File

@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
// start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")

View File

@@ -49,6 +49,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -1400,15 +1401,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Relay: &server.Relay{
config := &types.Config{
Stuns: []*types.Host{},
TURNConfig: &types.TURNConfig{},
Relay: &types.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &server.Host{
Signal: &types.Host{
Proto: "http",
URI: "localhost:10000",
},
@@ -1438,6 +1439,8 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -1446,7 +1449,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
return nil, "", err
}

View File

@@ -19,11 +19,9 @@ import (
type rcvChan chan *types.EventFields
type Logger struct {
mux sync.Mutex
ctx context.Context
cancel context.CancelFunc
enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc
cancel context.CancelFunc
statusRecorder *peer.Status
wgIfaceIPNet net.IPNet
dnsCollection atomic.Bool
@@ -31,12 +29,9 @@ type Logger struct {
Store types.Store
}
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
ctx, cancel := context.WithCancel(ctx)
return &Logger{
ctx: ctx,
cancel: cancel,
statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet,
Store: store.NewMemoryStore(),
@@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
}
l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx)
l.cancelReceiver = cancel
ctx, cancel := context.WithCancel(context.Background())
l.cancel = cancel
l.mux.Unlock()
c := make(rcvChan, 100)
@@ -109,7 +104,7 @@ func (l *Logger) startReceiver() {
}
}
func (l *Logger) Disable() {
func (l *Logger) Close() {
l.stop()
l.Store.Close()
}
@@ -121,9 +116,9 @@ func (l *Logger) stop() {
l.enabled.Store(false)
l.mux.Lock()
if l.cancelReceiver != nil {
l.cancelReceiver()
l.cancelReceiver = nil
if l.cancel != nil {
l.cancel()
l.cancel = nil
}
l.rcvChan.Store(nil)
l.mux.Unlock()
@@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
l.exitNodeCollection.Store(exitNodeCollection)
}
func (l *Logger) Close() {
l.stop()
l.cancel()
}
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {

View File

@@ -1,7 +1,6 @@
package logger_test
import (
"context"
"net"
"testing"
"time"
@@ -13,7 +12,7 @@ import (
)
func TestStore(t *testing.T) {
logger := logger.New(context.Background(), nil, net.IPNet{})
logger := logger.New(nil, net.IPNet{})
logger.Enable()
event := types.EventFields{
@@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
}
// test disable
logger.Disable()
logger.Close()
wait()
logger.StoreEvent(event)
wait()

View File

@@ -27,18 +27,18 @@ type Manager struct {
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
ctx context.Context
receiverClient *client.GRPCClient
publicKey []byte
cancel context.CancelFunc
}
// NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var ipNet net.IPNet
if iface != nil {
ipNet = *iface.Address().Network
}
flowLogger := logger.New(ctx, statusRecorder, ipNet)
flowLogger := logger.New(statusRecorder, ipNet)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
return &Manager{
logger: flowLogger,
conntrack: ct,
ctx: ctx,
publicKey: publicKey,
}
}
@@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up
if m.needsNewClient(previous) {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %v", err)
}
if err := m.resetClient(); err != nil {
return fmt.Errorf("reset client: %w", err)
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs(flowClient)
go m.startSender()
}
m.logger.Enable()
@@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
return nil
}
func (m *Manager) resetClient() error {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %v", err)
}
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
if m.cancel != nil {
m.cancel()
}
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
go m.receiveACKs(ctx, flowClient)
go m.startSender(ctx)
return nil
}
// disableFlow stops components for flow tracking
func (m *Manager) disableFlow() error {
if m.cancel != nil {
m.cancel()
}
if m.conntrack != nil {
m.conntrack.Stop()
}
m.logger.Disable()
m.logger.Close()
if m.receiverClient != nil {
return m.receiverClient.Close()
}
return nil
}
@@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
changed := previous != nil && update.Enabled != previous.Enabled
if update.Enabled {
log.Infof("netflow manager enabled; starting netflow manager")
if changed {
log.Infof("netflow manager enabled; starting netflow manager")
}
return m.enableFlow(previous)
}
log.Infof("netflow manager disabled; stopping netflow manager")
err := m.disableFlow()
if err != nil {
log.Errorf("failed to disable netflow manager: %v", err)
if changed {
log.Infof("netflow manager disabled; stopping netflow manager")
}
return err
return m.disableFlow()
}
// Close cleans up all resources
@@ -151,17 +172,9 @@ func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if m.conntrack != nil {
m.conntrack.Close()
if err := m.disableFlow(); err != nil {
log.Warnf("failed to disable flow manager: %v", err)
}
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %v", err)
}
}
m.logger.Close()
}
// GetLogger returns the flow logger
@@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
return m.logger
}
func (m *Manager) startSender() {
func (m *Manager) startSender(ctx context.Context) {
ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
case <-ctx.Done():
return
case <-ticker.C:
events := m.logger.GetEvents()
@@ -190,8 +203,8 @@ func (m *Manager) startSender() {
}
}
func (m *Manager) receiveACKs(client *client.GRPCClient) {
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
id, err := uuid.FromBytes(ack.EventId)
if err != nil {
log.Warnf("failed to convert ack event id to uuid: %v", err)

View File

@@ -0,0 +1,200 @@
package netflow
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
)
type mockIFaceMapper struct {
address wgaddr.Address
isUserspaceBind bool
}
func (m *mockIFaceMapper) Name() string {
return "wt0"
}
func (m *mockIFaceMapper) Address() wgaddr.Address {
return m.address
}
func (m *mockIFaceMapper) IsUserspaceBind() bool {
return m.isUserspaceBind
}
func TestManager_Update(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
},
isUserspaceBind: true,
}
publicKey := []byte("test-public-key")
statusRecorder := peer.NewRecorder("")
manager := NewManager(mockIFace, publicKey, statusRecorder)
tests := []struct {
name string
config *types.FlowConfig
}{
{
name: "nil config",
config: nil,
},
{
name: "disabled config",
config: &types.FlowConfig{
Enabled: false,
},
},
{
name: "enabled config with minimal valid settings",
config: &types.FlowConfig{
Enabled: true,
URL: "https://example.com",
TokenPayload: "test-payload",
TokenSignature: "test-signature",
Interval: 30 * time.Second,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := manager.Update(tc.config)
assert.NoError(t, err)
if tc.config == nil {
return
}
require.NotNil(t, manager.flowConfig)
if tc.config.Enabled {
assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
}
if tc.config.URL != "" {
assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
}
if tc.config.TokenPayload != "" {
assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
}
})
}
}
func TestManager_Update_TokenPreservation(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
},
isUserspaceBind: true,
}
publicKey := []byte("test-public-key")
manager := NewManager(mockIFace, publicKey, nil)
// First update with tokens
initialConfig := &types.FlowConfig{
Enabled: false,
TokenPayload: "initial-payload",
TokenSignature: "initial-signature",
}
err := manager.Update(initialConfig)
require.NoError(t, err)
// Second update without tokens should preserve them
updatedConfig := &types.FlowConfig{
Enabled: false,
URL: "https://example.com",
}
err = manager.Update(updatedConfig)
require.NoError(t, err)
// Verify tokens were preserved
assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload)
assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature)
}
func TestManager_NeedsNewClient(t *testing.T) {
manager := &Manager{}
tests := []struct {
name string
previous *types.FlowConfig
current *types.FlowConfig
expected bool
}{
{
name: "nil previous config",
previous: nil,
current: &types.FlowConfig{},
expected: true,
},
{
name: "previous disabled",
previous: &types.FlowConfig{Enabled: false},
current: &types.FlowConfig{Enabled: true},
expected: true,
},
{
name: "different URL",
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
expected: true,
},
{
name: "different TokenPayload",
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
expected: true,
},
{
name: "different TokenSignature",
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
expected: true,
},
{
name: "same config",
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
expected: false,
},
{
name: "only interval changed",
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
manager.flowConfig = tc.current
result := manager.needsNewClient(tc.previous)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -120,9 +120,6 @@ type FlowLogger interface {
Close()
// Enable enables the flow logger receiver
Enable()
// Disable disables the flow logger receiver
Disable()
// UpdateConfig updates the flow manager configuration
UpdateConfig(dnsCollection, exitNodeCollection bool)
}

View File

@@ -103,6 +103,7 @@ type Conn struct {
workerICE *WorkerICE
workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
@@ -211,6 +212,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
// Close closes this peer Conn issuing a close event to the Conn closeCh
func (conn *Conn) Close() {
conn.mu.Lock()
defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock()
conn.log.Infof("close peer connection")
@@ -252,6 +254,7 @@ func (conn *Conn) Close() {
}
conn.setStatusToDisconnected()
conn.log.Infof("peer connection has been closed")
}
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
@@ -362,6 +365,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
}
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
@@ -407,7 +411,12 @@ func (conn *Conn) onICEStateDisconnected() {
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.currentConnPriority = connPriorityRelay
} else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
@@ -476,7 +485,12 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay
@@ -502,6 +516,7 @@ func (conn *Conn) onRelayDisconnected() {
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.currentConnPriority = connPriorityNone
}
if conn.wgProxyRelay != nil {

View File

@@ -586,9 +586,8 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
if err == nil {
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
}
if d.localPeer.Routes == nil {
@@ -596,8 +595,6 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
}
d.localPeer.Routes[route] = struct{}{}
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
}
// RemoveLocalPeerStateRoute removes a route from the local peer state
@@ -606,14 +603,11 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
if err == nil {
d.routeIDLookup.RemoveLocalRouteID(pref)
}
delete(d.localPeer.Routes, route)
d.routeIDLookup.RemoveLocalRouteID(pref)
}
// CleanLocalPeerStateRoutes cleans all routes from the local peer state

View File

@@ -32,7 +32,6 @@ type WGWatcher struct {
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
waitGroup sync.WaitGroup
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -48,24 +47,24 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
return
}
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.waitGroup.Add(1)
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
}
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
@@ -81,13 +80,11 @@ func (w *WGWatcher) DisableWgWatcher() {
w.ctxCancel()
w.ctxCancel = nil
w.waitGroup.Wait()
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
defer w.waitGroup.Done()
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()

View File

@@ -49,7 +49,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
defer cancel()
onDisconnected := make(chan struct{}, 1)
watcher.EnableWgWatcher(ctx, func() {
go watcher.EnableWgWatcher(ctx, func() {
mlog.Infof("onDisconnectedFn")
onDisconnected <- struct{}{}
})
@@ -79,10 +79,11 @@ func TestWGWatcher_ReEnable(t *testing.T) {
onDisconnected := make(chan struct{}, 1)
watcher.EnableWgWatcher(ctx, func() {})
go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher()
watcher.EnableWgWatcher(ctx, func() {
go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{}
})

View File

@@ -65,6 +65,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally,
lastKnownState: ice.ConnectionStateDisconnected,
}
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
@@ -213,7 +214,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}

View File

@@ -84,6 +84,7 @@ func New(ctx context.Context, configPath, logFile string) *Server {
},
logFile: logFile,
persistNetworkMap: true,
statusRecorder: peer.NewRecorder(""),
}
}
@@ -136,9 +137,6 @@ func (s *Server) Start() error {
s.config = config
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
}
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
@@ -622,9 +620,6 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
return nil, fmt.Errorf("config is not defined, please call login command first")
}
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
}
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
@@ -692,9 +687,6 @@ func (s *Server) Status(
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
}
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)

View File

@@ -3,27 +3,32 @@ package server
import (
"context"
"net"
"net/url"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
daemonProto "github.com/netbirdio/netbird/client/proto"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
@@ -83,6 +88,72 @@ func TestConnectWithRetryRuns(t *testing.T) {
}
}
func TestServer_Up(t *testing.T) {
ctx := internal.CtxInitState(context.Background())
s := New(ctx, t.TempDir()+"/config.json", "console")
err := s.Start()
require.NoError(t, err)
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
s.config = &internal.Config{
ManagementURL: u,
}
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
upReq := &daemonProto.UpRequest{}
_, err = s.Up(upCtx, upReq)
assert.Contains(t, err.Error(), "NeedsLogin")
}
type mockSubscribeEventsServer struct {
ctx context.Context
sentEvents []*daemonProto.SystemEvent
grpc.ServerStream
}
func (m *mockSubscribeEventsServer) Send(event *daemonProto.SystemEvent) error {
m.sentEvents = append(m.sentEvents, event)
return nil
}
func (m *mockSubscribeEventsServer) Context() context.Context {
return m.ctx
}
func TestServer_SubcribeEvents(t *testing.T) {
ctx := internal.CtxInitState(context.Background())
s := New(ctx, t.TempDir()+"/config.json", "console")
err := s.Start()
require.NoError(t, err)
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
s.config = &internal.Config{
ManagementURL: u,
}
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
upReq := &daemonProto.SubscribeRequest{}
mockServer := &mockSubscribeEventsServer{
ctx: ctx,
sentEvents: make([]*daemonProto.SystemEvent, 0),
ServerStream: nil,
}
err = s.SubscribeEvents(upReq, mockServer)
assert.NoError(t, err)
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
@@ -97,10 +168,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Helper()
dataDir := t.TempDir()
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Signal: &server.Host{
config := &types.Config{
Stuns: []*types.Host{},
TURNConfig: &types.TURNConfig{},
Signal: &types.Host{
Proto: "http",
URI: signalAddr,
},
@@ -129,11 +200,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
return nil, "", err
}

View File

@@ -13,10 +13,12 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/flow/proto"
"github.com/netbirdio/netbird/util/embeddedroots"
@@ -77,17 +79,24 @@ func (c *GRPCClient) Close() error {
defer c.streamMu.Unlock()
c.stream = nil
return c.clientConn.Close()
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("close client connection: %w", err)
}
return nil
}
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
backOff := defaultBackoff(ctx, interval)
operation := func() error {
err := c.establishStreamAndReceive(ctx, msgHandler)
if err != nil {
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
}
log.Errorf("receive failed: %v", err)
return fmt.Errorf("receive: %w", err)
}
return err
return nil
}
if err := backoff.Retry(operation, backOff); err != nil {

256
flow/client/client_test.go Normal file
View File

@@ -0,0 +1,256 @@
package client_test
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
flow "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
)
type testServer struct {
proto.UnimplementedFlowServiceServer
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
}
func newTestServer(t *testing.T) *testServer {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s := &testServer{
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: listener.Addr().String(),
}
proto.RegisterFlowServiceServer(s.grpcSrv, s)
go func() {
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
t.Logf("server error: %v", err)
}
}()
t.Cleanup(func() {
s.grpcSrv.Stop()
})
return s
}
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
if err != nil {
return err
}
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
go func() {
defer cancel()
for {
event, err := stream.Recv()
if err != nil {
return
}
if !event.IsInitiator {
select {
case s.events <- event:
ack := &proto.FlowEventAck{
EventId: event.EventId,
}
select {
case s.acks <- ack:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}
}()
for {
select {
case ack := <-s.acks:
if err := stream.Send(ack); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
func TestReceive(t *testing.T) {
server := newTestServer(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
})
receivedAcks := make(map[string]bool)
receiveDone := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if !msg.IsInitiator && len(msg.EventId) > 0 {
id := string(msg.EventId)
receivedAcks[id] = true
if len(receivedAcks) >= 3 {
close(receiveDone)
}
}
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
t.Logf("receive error: %v", err)
}
}()
time.Sleep(500 * time.Millisecond)
for i := 0; i < 3; i++ {
eventID := uuid.New().String()
// Create acknowledgment and send it to the flow through our test server
ack := &proto.FlowEventAck{
EventId: []byte(eventID),
}
select {
case server.acks <- ack:
case <-time.After(time.Second):
t.Fatal("timeout sending ack")
}
}
select {
case <-receiveDone:
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for acks to be processed")
}
assert.Equal(t, 3, len(receivedAcks))
}
func TestReceive_ContextCancellation(t *testing.T) {
server := newTestServer(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
})
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
handlerCalled := false
msgHandler := func(msg *proto.FlowEventAck) error {
if !msg.IsInitiator {
handlerCalled = true
}
return nil
}
err = client.Receive(ctx, 1*time.Second, msgHandler)
assert.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
assert.False(t, handlerCalled)
}
func TestSend(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
ackReceived := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
if len(ack.EventId) > 0 && !ack.IsInitiator {
close(ackReceived)
}
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
t.Logf("receive error: %v", err)
}
}()
time.Sleep(500 * time.Millisecond)
testEvent := &proto.FlowEvent{
EventId: []byte("test-event-id"),
PublicKey: []byte("test-public-key"),
FlowFields: &proto.FlowFields{
FlowId: []byte("test-flow-id"),
Protocol: 6,
SourceIp: []byte{192, 168, 1, 1},
DestIp: []byte{192, 168, 1, 2},
ConnectionInfo: &proto.FlowFields_PortInfo{
PortInfo: &proto.PortInfo{
SourcePort: 12345,
DestPort: 443,
},
},
},
}
err = client.Send(testEvent)
require.NoError(t, err)
var receivedEvent *proto.FlowEvent
select {
case receivedEvent = <-server.events:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for event to be received by server")
}
assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)
select {
case <-ackReceived:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ack to be received by flow")
}
}

2
go.mod
View File

@@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -490,8 +490,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 h1:OsLDdb6ekNaCVSyD+omhio2DECfEqLjCA1zo4HrgGqU=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203 h1:uxxbLPXQgC9VO15epNPtrD6zazyd5rZeqC5hQSmCdZU=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -50,7 +51,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
config := &mgmt.Config{}
config := &types.Config{}
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
@@ -74,6 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -87,7 +89,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
t.Fatal(err)
}

View File

@@ -34,7 +34,9 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
@@ -50,7 +52,6 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -70,7 +71,7 @@ var (
mgmtSingleAccModeDomain string
certFile string
certKey string
config *server.Config
config *types.Config
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@@ -101,9 +102,9 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, mgmtConfig)
config, err = loadMgmtConfig(ctx, types.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err)
}
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
@@ -183,7 +184,7 @@ var (
if config.DataStoreEncryptionKey != key {
log.WithContext(ctx).Infof("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, mgmtConfig, config)
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
@@ -201,15 +202,14 @@ var (
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
permissionsManager := integrations.InitPermissionsManager(store)
userManager := users.NewManager(store)
extraSettingsManager := integrations.NewManager(eventStore)
settingsManager := settings.NewManager(store, userManager, extraSettingsManager)
permissionsManager := permissions.NewManager(userManager, settingsManager)
settingsManager := settings.NewManager(store, userManager, extraSettingsManager, permissionsManager)
peersManager := peers.NewManager(store, permissionsManager)
proxyController := integrations.NewController(store)
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager)
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}
@@ -486,8 +486,8 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
})
}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) {
loadedConfig := &types.Config{}
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil {
return nil, err
@@ -522,7 +522,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
@@ -539,7 +539,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope
}
}
@@ -560,7 +560,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
return loadedConfig, err
}
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
@@ -636,7 +636,7 @@ func handleRebrand(cmd *cobra.Command) error {
}
}
}
if mgmtConfig == defaultMgmtConfig {
if types.MgmtConfigPath == defaultMgmtConfig {
if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir)

View File

@@ -7,6 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/version"
)
@@ -19,7 +20,6 @@ const (
var (
dnsDomain string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
@@ -56,7 +56,7 @@ func init() {
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")

View File

@@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/status"
@@ -89,6 +90,8 @@ type DefaultAccountManager struct {
integratedPeerValidator integrated_validator.IntegratedValidator
metrics telemetry.AppMetrics
permissionsManager permissions.Manager
}
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
@@ -156,6 +159,7 @@ func BuildManager(
metrics telemetry.AppMetrics,
proxyController port_forwarding.Controller,
settingsManager settings.Manager,
permissionsManager permissions.Manager,
) (*DefaultAccountManager, error) {
start := time.Now()
defer func() {
@@ -180,6 +184,7 @@ func BuildManager(
requestBuffer: NewAccountRequestBuffer(ctx, store),
proxyController: proxyController,
settingsManager: settingsManager,
permissionsManager: permissionsManager,
}
accountsCounter, err := store.GetAccountsCounter(ctx)
if err != nil {
@@ -253,13 +258,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err
}
user, err := account.FindUser(userID)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Write)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !user.HasAdminPower() {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
if !allowed {
return nil, status.NewPermissionDeniedError()
}
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
@@ -503,16 +508,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err
}
user, err := account.FindUser(userID)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Accounts, permissions.Write)
if err != nil {
return err
return fmt.Errorf("failed to validate user permissions: %w", err)
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
}
if user.Role != types.UserRoleOwner {
if !allowed {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
@@ -542,14 +543,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
}
userInfo, ok := userInfosMap[userID]
if !ok {
return status.Errorf(status.NotFound, "user info not found for user %s", userID)
}
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
if err != nil {
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
return err
if ok {
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
if err != nil {
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
return err
}
}
err = am.Store.DeleteAccount(ctx, account)
@@ -1027,8 +1026,8 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
return nil, err
}
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
return am.Store.GetAccount(ctx, accountID)
@@ -1061,8 +1060,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
if user.AccountID != accountID {
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID)
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return "", "", err
}
if !user.IsServiceUser && userAuth.Invited {
@@ -1521,7 +1520,11 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
return nil, err
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
@@ -1606,3 +1609,113 @@ func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, ma
func (am *DefaultAccountManager) GetStore() store.Store {
return am.Store
}
// Creates account by private domain.
// Expects domain value to be a valid and a private dns domain.
func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel()
domain = strings.ToLower(domain)
count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain)
if err != nil {
return nil, err
}
if count > 0 {
return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists")
}
// retry twice for new ID clashes
for range 2 {
accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
if err != nil || exists {
continue
}
network := types.NewNetwork()
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*types.SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
dnsSettings := types.DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
newAccount := &types.Account{
Id: accountId,
CreatedAt: time.Now().UTC(),
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
// @todo check if using the MSP owner id here is ok
CreatedBy: initiatorId,
Domain: domain,
DomainCategory: types.PrivateCategory,
IsDomainPrimaryAccount: false,
Routes: routes,
NameServerGroups: nameServersGroups,
DNSSettings: dnsSettings,
Settings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
},
}
if err := newAccount.AddAllGroup(); err != nil {
return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
}
if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err)
return nil, err
}
am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
}
return nil, status.Errorf(status.Internal, "failed to create new account by private domain")
}
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return nil, err
}
if account.IsDomainPrimaryAccount {
return account, nil
}
// additional check to ensure there is only one account for this domain at the time of update
count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain)
if err != nil {
return nil, err
}
if count > 1 {
return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain")
}
account.IsDomainPrimaryAccount = true
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err)
return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id)
}
return account, nil
}

View File

@@ -111,4 +111,7 @@ type Manager interface {
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
}

View File

@@ -17,6 +17,7 @@ import (
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/util"
@@ -2793,15 +2794,15 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
})
}
type TB interface {
Cleanup(func())
Helper()
TempDir() string
Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{})
}
//type TB interface {
// Cleanup(func())
// Helper()
// TempDir() string
// Errorf(format string, args ...interface{})
// Fatalf(format string, args ...interface{})
//}
func createManager(t TB) (*DefaultAccountManager, error) {
func createManager(t testing.TB) (*DefaultAccountManager, error) {
t.Helper()
store, err := createStore(t)
@@ -2815,6 +2816,8 @@ func createManager(t TB) (*DefaultAccountManager, error) {
return nil, err
}
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
@@ -2828,7 +2831,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
Return(false, nil).
AnyTimes()
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
return nil, err
}
@@ -2836,7 +2839,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
return manager, nil
}
func createStore(t TB) (store.Store, error) {
func createStore(t testing.TB) (store.Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
@@ -3150,3 +3153,51 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
})
}
}
func Test_CreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
ctx := context.Background()
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
assert.Equal(t, initiatorId, account.CreatedBy)
assert.Equal(t, 1, len(account.Groups))
assert.Equal(t, 0, len(account.Users))
assert.Equal(t, 0, len(account.SetupKeys))
// retry should fail
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.Error(t, err)
}
func Test_UpdateToPrimaryAccount(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
ctx := context.Background()
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.False(t, account.IsDomainPrimaryAccount)
// retry should fail
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)
}

View File

@@ -67,8 +67,8 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -89,8 +89,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
if !user.HasAdminPower() {

View File

@@ -13,6 +13,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -210,13 +211,14 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
}
func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -10,6 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
func isEnabled() bool {
@@ -19,16 +21,12 @@ func isEnabled() bool {
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
@@ -58,6 +56,11 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
filtered = append(filtered, event)
}
err = am.fillEventsWithUserInfo(ctx, events, accountID, user)
if err != nil {
return nil, err
}
return filtered, nil
}
@@ -79,3 +82,156 @@ func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, ta
}()
}
}
type eventUserInfo struct {
email string
name string
accountId string
}
func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) error {
eventUserInfo, err := am.getEventsUserInfo(ctx, events, accountId, user)
if err != nil {
return err
}
for _, event := range events {
if !fillEventInitiatorInfo(eventUserInfo, event) {
log.WithContext(ctx).Warnf("failed to resolve user info for initiator: %s", event.InitiatorID)
}
fillEventTargetInfo(eventUserInfo, event)
}
return nil
}
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
if err != nil {
return nil, err
}
// @note check whether using a external initiator user here is an issue
userInfos, err := am.BuildUserInfosForAccount(ctx, accountId, user.Id, accountUsers)
if err != nil {
return nil, err
}
eventUserInfos := make(map[string]eventUserInfo)
for i, k := range userInfos {
eventUserInfos[i] = eventUserInfo{
email: k.Email,
name: k.Name,
accountId: accountId,
}
}
externalUserIds := []string{}
for _, event := range events {
if _, ok := eventUserInfos[event.InitiatorID]; ok {
continue
}
if event.InitiatorID == activity.SystemInitiator ||
event.InitiatorID == accountId ||
event.Activity == activity.PeerAddedWithSetupKey {
// @todo other events to be excluded if never initiated by a user
continue
}
externalUserIds = append(externalUserIds, event.InitiatorID)
}
if len(externalUserIds) == 0 {
return eventUserInfos, nil
}
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, user)
}
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, user *types.User) (map[string]eventUserInfo, error) {
externalAccountId := ""
fetched := make(map[string]struct{})
externalUsers := []*types.User{}
for _, id := range externalUserIds {
if _, ok := fetched[id]; ok {
continue
}
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
if err != nil {
// @todo consider logging
continue
}
if externalAccountId != "" && externalAccountId != externalUser.AccountID {
return nil, fmt.Errorf("multiple external user accounts in events")
}
if externalAccountId == "" {
externalAccountId = externalUser.AccountID
}
fetched[id] = struct{}{}
externalUsers = append(externalUsers, externalUser)
}
// if we couldn't determine an account, return what we have
if externalAccountId == "" {
log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds)
return eventUserInfos, nil
}
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, user.Id, externalUsers)
if err != nil {
return nil, err
}
for i, k := range externalUserInfos {
eventUserInfos[i] = eventUserInfo{
email: k.Email,
name: k.Name,
accountId: externalAccountId,
}
}
return eventUserInfos, nil
}
func fillEventTargetInfo(eventUserInfo map[string]eventUserInfo, event *activity.Event) {
userInfo, ok := eventUserInfo[event.TargetID]
if !ok {
return
}
if event.Meta == nil {
event.Meta = make(map[string]any)
}
event.Meta["email"] = userInfo.email
event.Meta["username"] = userInfo.name
}
func fillEventInitiatorInfo(eventUserInfo map[string]eventUserInfo, event *activity.Event) bool {
userInfo, ok := eventUserInfo[event.InitiatorID]
if !ok {
return false
}
if event.InitiatorEmail == "" {
event.InitiatorEmail = userInfo.email
}
if event.InitiatorName == "" {
event.InitiatorName = userInfo.name
}
if event.AccountID != userInfo.accountId {
if event.Meta == nil {
event.Meta = make(map[string]any)
}
event.Meta["external"] = true
}
return true
}

View File

@@ -10,13 +10,13 @@ import (
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
)
type GroupLinkError struct {
@@ -35,8 +35,8 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
if user.IsRegularUser() {
@@ -83,8 +83,8 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
if user.IsRegularUser() {
@@ -215,8 +215,8 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
if user.IsRegularUser() {
@@ -498,6 +498,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return &GroupLinkError{"user", linkedUser.Id}
}
if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"network router", linkedRouter.ID}
}
return checkGroupLinkedToSettings(ctx, transaction, group)
}
@@ -613,6 +617,22 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
return false, nil
}
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil
}
for _, router := range routers {
if slices.Contains(router.PeerGroups, groupID) {
return true, router
}
}
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
@@ -637,6 +657,9 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil

View File

@@ -12,6 +12,13 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
@@ -414,6 +421,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Name: "GroupD",
Peers: []string{},
},
{
ID: "groupE",
Name: "GroupE",
Peers: []string{peer2.ID},
},
})
assert.NoError(t, err)
@@ -673,4 +685,51 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving a group linked to network router should update account peers and send peer update
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
ID: "network_test",
AccountID: account.Id,
Name: "network_test",
Description: "",
})
require.NoError(t, err)
_, err = routersManager.CreateRouter(context.Background(), userID, &routerTypes.NetworkRouter{
ID: "router_test",
NetworkID: network.ID,
AccountID: account.Id,
PeerGroups: []string{"groupE"},
Masquerade: true,
Metric: 9999,
Enabled: true,
})
require.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupE",
Name: "GroupE",
Peers: []string{peer2.ID, peer3.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/account"
@@ -40,7 +41,7 @@ type GRPCServer struct {
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
config *Config
config *types.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
@@ -51,7 +52,7 @@ type GRPCServer struct {
// NewServer creates a new Management server
func NewServer(
ctx context.Context,
config *Config,
config *types.Config,
accountManager account.Manager,
settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager,
@@ -530,24 +531,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil
}
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol {
switch configProto {
case UDP:
case types.UDP:
return proto.HostConfig_UDP
case DTLS:
case types.DTLS:
return proto.HostConfig_DTLS
case HTTP:
case types.HTTP:
return proto.HostConfig_HTTP
case HTTPS:
case types.HTTPS:
return proto.HostConfig_HTTPS
case TCP:
case types.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
@@ -610,8 +611,6 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token,
Relay: relayCfg,
}
integrationsConfig.ExtendNetBirdConfig(nbConfig, extraSettings)
return nbConfig
}
@@ -626,10 +625,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dns
}
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
response := &proto.SyncResponse{
NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
@@ -638,6 +636,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
@@ -754,7 +756,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) {
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}

View File

@@ -1,7 +1,6 @@
package events
import (
"context"
"fmt"
"net/http"
@@ -47,66 +46,15 @@ func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
events := make([]*api.Event, len(accountEvents))
for i, e := range accountEvents {
events[i] = toEventResponse(e)
}
err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, events)
}
func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
// build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
if err != nil {
log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
return err
}
emails := make(map[string]string)
names := make(map[string]string)
for _, ui := range userInfos {
emails[ui.ID] = ui.Email
names[ui.ID] = ui.Name
}
var ok bool
for _, event := range events {
// fill initiator
if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok {
log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
}
}
if event.InitiatorName == "" {
// here to allowed to be empty because in the first release we did not store the name
event.InitiatorName = names[event.InitiatorId]
}
// fill target meta
email, ok := emails[event.TargetId]
if !ok {
continue
}
event.Meta["email"] = email
username, ok := names[event.TargetId]
if !ok {
continue
}
event.Meta["username"] = username
}
return nil
}
func toEventResponse(event *activity.Event) *api.Event {
meta := make(map[string]string)
if event.Meta != nil {

View File

@@ -250,7 +250,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
user, err := account.FindUser(userID)
user, err := h.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -258,7 +258,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser {
if !user.HasAdminPower() && !user.IsServiceUser && !userAuth.IsChild {
peer, ok := account.Peers[peerID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)

View File

@@ -122,6 +122,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
}
return p, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
switch id {
case adminUser:
return account.Users[adminUser], nil
case regularUser:
return account.Users[regularUser], nil
case serviceUser:
return account.Users[serviceUser], nil
default:
return nil, fmt.Errorf("user not found")
}
},
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
return peers, nil
},

View File

@@ -301,7 +301,7 @@ func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -142,6 +142,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, fmt.Errorf("token expired")
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if user.AccountID != impersonate[0] {
return r, fmt.Errorf("token is not valid for this account")
}
}
err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
return r, err

View File

@@ -239,16 +239,9 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
},
{
name: "Valid PAT Token ignores child",
name: "Valid PAT Token returns forbidden for child",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsPAT: true,
},
},
{
name: "Valid JWT Token",

View File

@@ -16,19 +16,18 @@ import (
"github.com/golang-jwt/jwt"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -124,8 +123,9 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
validatorMock := server.MocIntegratedValidator{}
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}))
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager)
permissionsManagerMock := permissions.NewManagerMock()
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManagerMock)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManagerMock)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
@@ -143,7 +143,6 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
permissionsManagerMock := permissions.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManagerMock)
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManagerMock, peersManager, settingsManager)

View File

@@ -25,6 +25,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -93,21 +94,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@@ -330,7 +331,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testCases := []struct {
name string
inputFlow *DeviceAuthorizationFlow
inputFlow *types.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
@@ -345,9 +346,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &DeviceAuthorizationFlow{
inputFlow: &types.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: ProviderConfig{
ProviderConfig: types.ProviderConfig{
ClientID: "test",
},
},
@@ -356,9 +357,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Full Device Flow Config",
inputFlow: &DeviceAuthorizationFlow{
inputFlow: &types.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: ProviderConfig{
ProviderConfig: types.ProviderConfig{
ClientID: "test",
},
},
@@ -379,7 +380,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{
wgKey: testingServerKey,
config: &Config{
config: &types.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
@@ -410,7 +411,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
@@ -431,6 +432,8 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -441,7 +444,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
Return(&types.Settings{}, nil)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
cleanup()
@@ -506,21 +509,21 @@ func testSyncStatusRace(t *testing.T) {
t.Skip()
dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@@ -678,21 +681,21 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper()
dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},

View File

@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -58,7 +59,7 @@ func setupTest(t *testing.T) *testSuite {
t.Fatalf("failed to create temp directory: %v", err)
}
config := &server.Config{}
config := &types.Config{}
_, err = util.ReadJson("testdata/management.json", config)
if err != nil {
t.Fatalf("failed to read management.json: %v", err)
@@ -156,7 +157,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
func startServer(
t *testing.T,
config *server.Config,
config *types.Config,
dataDir string,
testFile string,
) (*grpc.Server, net.Listener) {
@@ -194,6 +195,7 @@ func startServer(
Return(&types.Settings{}, nil).
AnyTimes()
permissionsManagerMock := permissions.NewManagerMock()
accountManager, err := server.BuildManager(
context.Background(),
str,
@@ -208,6 +210,7 @@ func startServer(
metrics,
port_forwarding.NewControllerMock(),
settingsMockManager,
permissionsManagerMock,
)
if err != nil {
t.Fatalf("failed creating an account manager: %v", err)
@@ -300,6 +303,10 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
Protocol: mgmtProto.HostConfig_UDP,
}
expectedRelayHost := &mgmtProto.RelayConfig{
Urls: []string{"rel://test.com:3535"},
}
assert.NotNil(t, resp.NetbirdConfig)
assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig)
assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig)
@@ -307,6 +314,8 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
actualTURN := resp.NetbirdConfig.Turns[0]
assert.Greater(t, len(actualTURN.User), 0)
assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost)
assert.Equal(t, len(resp.NetbirdConfig.Relay.Urls), 1)
assert.Equal(t, resp.NetbirdConfig.Relay.Urls, expectedRelayHost.Urls)
assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0)
}

View File

@@ -15,7 +15,6 @@ import (
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version"
)
@@ -49,7 +48,7 @@ type properties map[string]interface{}
// DataSource metric data source
type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() store.Engine
GetStoreEngine() types.Engine
}
// ConnManager peer connection manager that holds state for current active connections

View File

@@ -10,7 +10,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -205,8 +204,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}
// GetStoreEngine returns FileStoreEngine
func (mockDatasource) GetStoreEngine() store.Engine {
return store.FileStoreEngine
func (mockDatasource) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
@@ -304,7 +303,7 @@ func TestGenerateProperties(t *testing.T) {
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
}
if properties["store_engine"] != store.FileStoreEngine {
if properties["store_engine"] != types.FileStoreEngine {
t.Errorf("expected JsonFile, got %s", properties["store_engine"])
}

View File

@@ -112,6 +112,9 @@ type MockAccountManager struct {
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
GetStoreFunc func() store.Store
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
@@ -847,3 +850,24 @@ func (am *MockAccountManager) GetStore() store.Store {
}
return nil
}
func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
if am.CreateAccountByPrivateDomainFunc != nil {
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
}
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
if am.UpdateToPrimaryAccountFunc != nil {
return am.UpdateToPrimaryAccountFunc(ctx, accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
}
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {
if am.GetOwnerInfoFunc != nil {
return am.GetOwnerInfoFunc(ctx, accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
}

View File

@@ -25,8 +25,8 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -46,8 +46,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
newNSGroup := &nbdns.NameServerGroup{
@@ -108,8 +108,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
var updateAccountPeers bool
@@ -159,8 +159,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
var nsGroup *nbdns.NameServerGroup
@@ -203,8 +203,8 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -774,11 +775,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
}
func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -37,8 +37,8 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
@@ -188,8 +188,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
var peer *nbpeer.Peer
@@ -321,8 +321,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
}
@@ -621,7 +621,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
return fmt.Errorf("failed to update user last login: %w", err)
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
@@ -1054,7 +1054,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin())
if err != nil {
return err
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
@@ -1099,8 +1099,8 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
@@ -1213,7 +1213,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return
}
update := toSyncResponse(ctx, &Config{}, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}
@@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
update := toSyncResponse(ctx, &Config{}, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}

View File

@@ -20,9 +20,10 @@ import (
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/util"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -723,7 +724,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
@@ -998,6 +999,53 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
func TestUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 1},
{"Medium", 500, 1},
{"Large", 1000, 1},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
t.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
}
})
}
}
func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
@@ -1008,16 +1056,16 @@ func TestToSyncResponse(t *testing.T) {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
config := &types.Config{
Signal: &types.Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}},
TURNConfig: &types.TURNConfig{
Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
@@ -1217,7 +1265,8 @@ func Test_RegisterPeerByUser(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
permissionsManagerMock := permissions.NewManagerMock()
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1285,7 +1334,8 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
permissionsManagerMock := permissions.NewManagerMock()
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1356,7 +1406,8 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
permissionsManagerMock := permissions.NewManagerMock()
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"

View File

@@ -5,10 +5,9 @@ import (
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
)
type Module string
@@ -17,6 +16,8 @@ const (
Networks Module = "networks"
Peers Module = "peers"
Groups Module = "groups"
Settings Module = "settings"
Accounts Module = "accounts"
)
type Operation string
@@ -28,42 +29,50 @@ const (
type Manager interface {
ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error)
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
}
type managerImpl struct {
userManager users.Manager
settingsManager settings.Manager
store store.Store
}
type managerMock struct {
}
func NewManager(userManager users.Manager, settingsManager settings.Manager) Manager {
func NewManager(store store.Store) Manager {
return &managerImpl{
userManager: userManager,
settingsManager: settingsManager,
store: store,
}
}
func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
user, err := m.userManager.GetUser(ctx, userID)
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return false, err
}
if user == nil {
return false, errors.New("user not found")
return false, status.NewUserNotFoundError(userID)
}
if user.AccountID != accountID {
return false, errors.New("user does not belong to account")
if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return false, err
}
switch module {
case Accounts:
if operation == Write && user.Role != types.UserRoleOwner {
return false, nil
}
return true, nil
default:
}
switch user.Role {
case types.UserRoleAdmin, types.UserRoleOwner:
return true, nil
case types.UserRoleUser:
return m.validateRegularUserPermissions(ctx, accountID, userID, module, operation)
return m.validateRegularUserPermissions(ctx, accountID, module, operation)
case types.UserRoleBillingAdmin:
return false, nil
default:
@@ -71,8 +80,8 @@ func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, us
}
}
func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
settings, err := m.settingsManager.GetSettings(ctx, accountID, activity.SystemInitiator)
func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID string, module Module, operation Operation) (bool, error) {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, fmt.Errorf("failed to get settings: %w", err)
}
@@ -91,13 +100,30 @@ func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accoun
return false, nil
}
func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
return nil
}
func NewManagerMock() Manager {
return &managerMock{}
}
func (m *managerMock) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) {
if userID == "allowedUser" {
switch userID {
case "a23efe53-63fb-11ec-90d6-0242ac120003", "allowedUser", "testingUser", "account_creator", "serviceUserID", "test_user":
return true, nil
default:
return false, nil
}
return false, nil
}
func (m *managerMock) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
// @note managers explicitly checked this, so should the mock
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
return nil
}

View File

@@ -22,8 +22,8 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -43,8 +43,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -100,8 +100,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
if user.IsRegularUser() {
@@ -148,8 +148,8 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {

View File

@@ -22,8 +22,8 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if !user.HasAdminPower() {
@@ -43,8 +43,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if !user.HasAdminPower() {
@@ -99,8 +99,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
if !user.HasAdminPower() {
@@ -141,8 +141,8 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if !user.HasAdminPower() {

View File

@@ -25,7 +25,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if !user.IsAdminOrServiceUser() {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
@@ -119,16 +123,18 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
// Do not allow non-Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
if len(domains) > 0 && prefix.IsValid() {
@@ -236,16 +242,18 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
@@ -310,6 +318,15 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
@@ -342,7 +359,11 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if !user.IsAdminOrServiceUser() {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}

View File

@@ -21,6 +21,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -1116,14 +1117,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
assert.Len(t, peer2RoutesAfterDelete.Routes, 3, "after peer deletion group should have 3 client routes")
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
require.NoError(t, err)
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have only 2 routes")
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
require.NoError(t, err)
@@ -1134,7 +1135,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
peer2RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
assert.Len(t, peer2RoutesAfterAdd.Routes, 3, "HA route should have 3 client routes")
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID)
require.NoError(t, err)
@@ -1259,6 +1260,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
permissionsManagerMock := permissions.NewManagerMock()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
@@ -1281,7 +1283,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes().
Return(&types.ExtraSettings{}, nil)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
}
func createRouterStore(t *testing.T) (store.Store, error) {
@@ -1492,7 +1494,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
{
ID: routeGroupHA1,
Name: routeGroupHA1,
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, // we have one non Linux peer, see peer3
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
{
ID: routeGroupHA2,

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -25,13 +26,15 @@ type managerImpl struct {
store store.Store
extraSettingsManager extra_settings.Manager
userManager users.Manager
permissionsManager permissions.Manager
}
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager) Manager {
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager {
return &managerImpl{
store: store,
extraSettingsManager: extraSettingsManager,
userManager: userManager,
permissionsManager: permissionsManager,
}
}
@@ -41,13 +44,12 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
if userID != activity.SystemInitiator {
user, err := m.userManager.GetUser(ctx, userID)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Read)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
return nil, status.NewPermissionValidationError(err)
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
if !ok {
return nil, status.NewPermissionDeniedError()
}
}

View File

@@ -61,8 +61,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -118,8 +118,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -180,8 +180,8 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -198,8 +198,8 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return nil, err
}
if user.IsRegularUser() {
@@ -226,8 +226,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return err
}
if user.IsRegularUser() {

View File

@@ -183,7 +183,7 @@ func NewPermissionDeniedError() error {
}
func NewPermissionValidationError(err error) error {
return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err)
return Errorf(PermissionDenied, "failed to validate user permissions: %s", err)
}
func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {

View File

@@ -260,6 +260,6 @@ func (s *FileStore) Close(ctx context.Context) error {
}
// GetStoreEngine returns FileStoreEngine
func (s *FileStore) GetStoreEngine() Engine {
return FileStoreEngine
func (s *FileStore) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}

View File

@@ -55,7 +55,7 @@ type SqlStore struct {
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine Engine
storeEngine types.Engine
}
type installation struct {
@@ -66,7 +66,7 @@ type installation struct {
type migrationFunc func(*gorm.DB) error
// NewSqlStore creates a new SqlStore instance.
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
sql, err := db.DB()
if err != nil {
return nil, err
@@ -77,7 +77,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
conns = runtime.NumCPU()
}
if storeEngine == SqliteStoreEngine {
if storeEngine == types.SqliteStoreEngine {
if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
}
@@ -105,7 +105,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
}
func GetKeyQueryCondition(s *SqlStore) string {
if s.storeEngine == MysqlStoreEngine {
if s.storeEngine == types.MysqlStoreEngine {
return mysqlKeyQueryCondition
}
return keyQueryCondition
@@ -586,6 +586,19 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
return users, nil
}
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
var user types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
}
return nil, status.Errorf(status.Internal, "failed to get account owner from the store")
}
return &user, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
var groups []*types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
@@ -970,7 +983,7 @@ func (s *SqlStore) Close(_ context.Context) error {
}
// GetStoreEngine returns underlying store engine
func (s *SqlStore) GetStoreEngine() Engine {
func (s *SqlStore) GetStoreEngine() types.Engine {
return s.storeEngine
}
@@ -988,7 +1001,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
return nil, err
}
return NewSqlStore(ctx, db, SqliteStoreEngine, metrics)
return NewSqlStore(ctx, db, types.SqliteStoreEngine, metrics)
}
// NewPostgresqlStore creates a new Postgres store.
@@ -998,7 +1011,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return nil, err
}
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics)
}
// NewMysqlStore creates a new MySQL store.
@@ -1008,7 +1021,7 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
return nil, err
}
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics)
return NewSqlStore(ctx, db, types.MysqlStoreEngine, metrics)
}
func getGormConfig() *gorm.Config {
@@ -1517,9 +1530,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
switch s.storeEngine {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case MysqlStoreEngine:
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")
@@ -2194,3 +2207,17 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
return &peer, nil
}
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
var count int64
result := s.db.Model(&types.Account{}).
Where("domain = ? AND domain_category = ?",
strings.ToLower(domain), types.PrivateCategory,
).Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count accounts by private domain %s: %s", domain, result.Error)
return 0, status.Errorf(status.Internal, "failed to count accounts by private domain")
}
return count, nil
}

View File

@@ -297,7 +297,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -628,7 +628,7 @@ func TestMigrate(t *testing.T) {
}
// TODO: figure out why this fails on postgres
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
@@ -737,7 +737,7 @@ func TestPostgresql_NewStore(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -752,7 +752,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -825,7 +825,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -901,7 +901,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -921,7 +921,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -935,7 +935,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
}
func TestSqlite_GetTakenIPs(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
defer cleanup()
if err != nil {
@@ -980,7 +980,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
if err != nil {
return
@@ -1022,7 +1022,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1045,7 +1045,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1070,7 +1070,7 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1106,7 +1106,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
}
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1211,7 +1211,7 @@ func TestSqlite_GetGroupByName(t *testing.T) {
}
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@@ -1227,7 +1227,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
}
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)

View File

@@ -69,10 +69,12 @@ type Store interface {
DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error)
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@@ -164,7 +166,7 @@ type Store interface {
Close(ctx context.Context) error
// GetStoreEngine should return Engine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() Engine
GetStoreEngine() types.Engine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
@@ -187,44 +189,37 @@ type Store interface {
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
}
type Engine string
const (
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
PostgresStoreEngine Engine = "postgres"
MysqlStoreEngine Engine = "mysql"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
func getStoreEngineFromEnv() Engine {
func getStoreEngineFromEnv() types.Engine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
if !ok {
return ""
}
value := Engine(strings.ToLower(kind))
value := types.Engine(strings.ToLower(kind))
if slices.Contains(supportedEngines, value) {
return value
}
return SqliteStoreEngine
return types.SqliteStoreEngine
}
// getStoreEngine determines the store engine to use.
// If no engine is specified, it attempts to retrieve it from the environment.
// If still not specified, it defaults to using SQLite.
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) types.Engine {
if kind == "" {
kind = getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
kind = types.SqliteStoreEngine
// Migrate if it is the first run with a JSON file existing and no SQLite file present
jsonStoreFile := filepath.Join(dataDir, storeFileName)
@@ -236,7 +231,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
// Attempt to migrate from JSON store to SQLite
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
kind = FileStoreEngine
kind = types.FileStoreEngine
}
}
}
@@ -246,7 +241,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
}
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
func NewStore(ctx context.Context, kind types.Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
kind = getStoreEngine(ctx, dataDir, kind)
if err := checkFileStoreEngine(kind, dataDir); err != nil {
@@ -254,13 +249,13 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
}
switch kind {
case SqliteStoreEngine:
case types.SqliteStoreEngine:
log.WithContext(ctx).Info("using SQLite store engine")
return NewSqliteStore(ctx, dataDir, metrics)
case PostgresStoreEngine:
case types.PostgresStoreEngine:
log.WithContext(ctx).Info("using Postgres store engine")
return newPostgresStore(ctx, metrics)
case MysqlStoreEngine:
case types.MysqlStoreEngine:
log.WithContext(ctx).Info("using MySQL store engine")
return newMysqlStore(ctx, metrics)
default:
@@ -268,12 +263,12 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
}
}
func checkFileStoreEngine(kind Engine, dataDir string) error {
if kind == FileStoreEngine {
func checkFileStoreEngine(kind types.Engine, dataDir string) error {
if kind == types.FileStoreEngine {
storeFile := filepath.Join(dataDir, storeFileName)
if util.FileExists(storeFile) {
return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine)
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", types.FileStoreEngine)
}
}
return nil
@@ -326,7 +321,7 @@ func getMigrations(ctx context.Context) []migrationFunc {
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
kind := getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
kind = types.SqliteStoreEngine
}
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
@@ -348,7 +343,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
}
}
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
store, err := NewSqlStore(ctx, db, types.SqliteStoreEngine, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
}
@@ -394,13 +389,13 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
return nil
}
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) (Store, func(), error) {
var cleanup func()
var err error
switch kind {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
case MysqlStoreEngine:
case types.MysqlStoreEngine:
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
default:
cleanup = func() {
@@ -419,7 +414,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store
return store, closeConnection, nil
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreatePostgresTestContainer()
@@ -451,7 +446,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (
return store, cleanup, nil
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreateMysqlTestContainer()
@@ -483,7 +478,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*Sq
return store, cleanup, nil
}
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
@@ -493,9 +488,9 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err
var err error
cleanup := func() {
switch engine {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
case MysqlStoreEngine:
case types.MysqlStoreEngine:
// err = killMySQLConnections(dsn, dbName)
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
}

View File

@@ -20,6 +20,11 @@
"Secret": "c29tZV9wYXNzd29yZA==",
"TimeBasedCredentials": true
},
"Relay":{
"Addresses":["rel://test.com:3535"],
"CredentialsTTL":"2h",
"Secret":"netbird"
},
"Signal": {
"Proto": "http",
"URI": "signal.netbird.io:10000",

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
@@ -32,8 +33,8 @@ type SecretsManager interface {
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *TURNConfig
relayCfg *Relay
turnCfg *types.TURNConfig
relayCfg *types.Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
@@ -44,7 +45,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@@ -221,7 +222,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
}
}
m.extendNetbirdConfig(ctx, accountID, update)
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
@@ -245,17 +246,18 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
},
}
m.extendNetbirdConfig(ctx, accountID, update)
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, accountID string, update *proto.SyncResponse) {
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
}
integrationsConfig.ExtendNetBirdConfig(update.NetbirdConfig, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
update.NetbirdConfig = extendedConfig
}

View File

@@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &Host{
Proto: UDP,
var TurnTestHost = &types.Host{
Proto: types.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
@@ -31,7 +31,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -41,10 +41,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
@@ -81,7 +81,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -92,10 +92,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
@@ -184,7 +184,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
peersManager := NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -194,10 +194,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)

View File

@@ -149,11 +149,6 @@ func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enab
return enabledRoutes, disabledRoutes
}
// currently we support only linux routing peers
if peer.Meta.GoOS != "linux" {
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) {

View File

@@ -1,10 +1,9 @@
package server
package types
import (
"net/netip"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
@@ -30,6 +29,8 @@ const (
DefaultDeviceAuthFlowScope string = "openid"
)
var MgmtConfigPath string
// Config of the Management service
type Config struct {
Stuns []*Host
@@ -76,6 +77,7 @@ type TURNConfig struct {
Turns []*Host
}
// Relay configuration type
type Relay struct {
Addresses []string
CredentialsTTL util.Duration
@@ -156,7 +158,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration
type StoreConfig struct {
Engine store.Engine
Engine Engine
}
// ReverseProxy contains reverse proxy configuration in front of management.

View File

@@ -0,0 +1,10 @@
package types
type Engine string
const (
PostgresStoreEngine Engine = "postgres"
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
MysqlStoreEngine Engine = "mysql"
)

View File

@@ -30,8 +30,8 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
return nil, err
}
if initiatorUser.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return nil, err
}
if !initiatorUser.HasAdminPower() {
@@ -93,8 +93,8 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, err
}
if initiatorUser.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return nil, err
}
inviterID := userID
@@ -142,12 +142,21 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
// createNewIdpUser validates the invite and creates a new user in the IdP
func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
inviter, err := am.GetUserByID(ctx, inviterID)
if err != nil {
return nil, fmt.Errorf("failed to get inviter user: %w", err)
}
// inviterUser is the one who is inviting the new user
inviterUser, err := am.lookupUserInCache(ctx, inviterID, accountID)
inviterUser, err := am.lookupUserInCache(ctx, inviterID, inviter.AccountID)
if err != nil {
return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID)
}
if inviterUser == nil {
return nil, status.Errorf(status.NotFound, "inviter user with ID %s is empty", inviterID)
}
// check if the user is already registered with this email => reject
user, err := am.lookupUserInCacheByEmail(ctx, invite.Email, accountID)
if err != nil {
@@ -188,7 +197,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
if newLogin {
@@ -228,8 +237,8 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return err
}
if initiatorUser.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return err
}
if !initiatorUser.HasAdminPower() {
@@ -290,8 +299,8 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
return err
}
if initiatorUser.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return err
}
// check if the user is already registered with this ID
@@ -338,8 +347,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, err
}
if initiatorUser.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
@@ -376,8 +385,8 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return err
}
if initiatorUser.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return err
}
if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
@@ -411,8 +420,8 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, err
}
if initiatorUser.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return nil, err
}
if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
@@ -429,8 +438,8 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, err
}
if initiatorUser.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return nil, err
}
if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() {
@@ -476,8 +485,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
return nil, err
}
if initiatorUser.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return nil, err
}
if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() {
@@ -511,7 +520,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, initiatorUser, update, addIfNotExists, settings,
ctx, transaction, groupsMap, accountID, initiatorUser, update, addIfNotExists, settings,
)
if err != nil {
return fmt.Errorf("failed to process user update: %w", err)
@@ -597,13 +606,13 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
}
func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group,
initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
accountID string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
if update == nil {
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, update, addIfNotExists)
oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
if err != nil {
return false, nil, nil, nil, err
}
@@ -614,7 +623,6 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
// only auto groups, revoked status, and integration reference can be updated for now
updatedUser := oldUser.Copy()
updatedUser.AccountID = initiatorUser.AccountID
updatedUser.Role = update.Role
updatedUser.Blocked = update.Blocked
updatedUser.AutoGroups = update.AutoGroups
@@ -657,17 +665,23 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
}
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) {
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
}
update.AccountID = accountID
return update, nil // use all fields from update if addIfNotExists is true
}
return nil, err
}
if existingUser.AccountID != accountID {
return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch")
}
return existingUser, nil
}
@@ -705,6 +719,7 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.Us
// validateUserUpdate validates the update operation for a user.
func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error {
// @todo double check these
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
@@ -790,8 +805,8 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
return nil, err
}
if initiatorUser.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, true); err != nil {
return nil, err
}
return am.BuildUserInfosForAccount(ctx, accountID, initiatorUserID, accountUsers)
@@ -967,6 +982,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return err
}
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil {
return err
}
if !initiatorUser.HasAdminPower() {
return status.NewAdminPermissionError()
}
@@ -1081,6 +1100,25 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return updateAccountPeers, nil
}
// GetOwnerInfo retrieves the owner information for a given account ID.
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if owner == nil {
return nil, status.Errorf(status.NotFound, "owner not found")
}
userInfo, err := am.getUserInfo(ctx, owner, accountID)
if err != nil {
return nil, err
}
return userInfo, nil
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {

View File

@@ -8,11 +8,11 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"golang.org/x/exp/maps"
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/util"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -59,9 +59,11 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: s,
eventStore: &activity.InMemoryEventStore{},
Store: s,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
@@ -107,9 +109,11 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
@@ -133,9 +137,11 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
@@ -160,9 +166,11 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn)
@@ -183,9 +191,11 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
@@ -214,9 +224,11 @@ func TestUser_DeletePAT(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
@@ -255,9 +267,11 @@ func TestUser_GetPAT(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
@@ -296,9 +310,11 @@ func TestUser_GetAllPATs(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID)
@@ -390,9 +406,11 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"})
@@ -435,9 +453,11 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
@@ -481,9 +501,11 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
@@ -510,10 +532,12 @@ func TestUser_InviteNewUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
cacheLoading: map[string]chan struct{}{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
cacheLoading: map[string]chan struct{}{},
permissionsManager: permissionsMananagerMock,
}
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
@@ -616,9 +640,11 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID)
@@ -652,9 +678,11 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID)
@@ -704,10 +732,12 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{},
permissionsManager: permissionsMananagerMock,
}
testCases := []struct {
@@ -812,10 +842,12 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{},
permissionsManager: permissionsMananagerMock,
}
testCases := []struct {
@@ -921,9 +953,11 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
claims := nbcontext.UserAuth{
@@ -957,9 +991,11 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
users, err := am.ListUsers(context.Background(), mockAccountID)
@@ -1044,9 +1080,11 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
users, err := am.ListUsers(context.Background(), mockAccountID)
@@ -1087,11 +1125,13 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
cacheLoading: map[string]chan struct{}{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
cacheLoading: map[string]chan struct{}{},
permissionsManager: permissionsMananagerMock,
}
cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
@@ -1148,9 +1188,11 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
@@ -1180,9 +1222,11 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
permissionsMananagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsMananagerMock,
}
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID)
@@ -1525,3 +1569,41 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
}
func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "")
targetId := "user2"
account1.Users[targetId] = &types.User{
Id: targetId,
AccountID: account1.Id,
ServiceUserName: "user2username",
}
require.NoError(t, s.SaveAccount(context.Background(), account1))
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
require.NoError(t, s.SaveAccount(context.Background(), account2))
permissionsManagerMock := permissions.NewManagerMock()
am := DefaultAccountManager{
Store: s,
eventStore: &activity.InMemoryEventStore{},
idpManager: nil,
cacheLoading: map[string]chan struct{}{},
permissionsManager: permissionsManagerMock,
}
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
assert.Error(t, err, "update user to another account should fail")
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId)
require.NoError(t, err)
assert.Equal(t, account1.Users[targetId].Id, user.Id)
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)
assert.Equal(t, account1.Users[targetId].AutoGroups, user.AutoGroups)
}