Compare commits

...

3 Commits

Author SHA1 Message Date
Pascal Fischer
3228a3206c modify store lock logs 2025-04-29 11:48:31 +02:00
Pascal Fischer
463d402000 improve getFreeIP and getFreeDNS [WIP] 2025-04-28 19:40:04 +02:00
Pascal Fischer
d40f60db94 add gorm tag for primary key for the networks objects 2025-04-28 14:24:17 +02:00
10 changed files with 138 additions and 52 deletions

View File

@@ -1595,7 +1595,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
} }
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) existingLabels, err := s.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerHostName)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err) return "", fmt.Errorf("failed to get peer dns labels: %w", err)
} }

View File

@@ -30,7 +30,7 @@ func (p NetworkResourceType) String() string {
} }
type NetworkResource struct { type NetworkResource struct {
ID string `gorm:"index"` ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"` NetworkID string `gorm:"index"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Name string Name string

View File

@@ -10,7 +10,7 @@ import (
) )
type NetworkRouter struct { type NetworkRouter struct {
ID string `gorm:"index"` ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"` NetworkID string `gorm:"index"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Peer string Peer string

View File

@@ -7,7 +7,7 @@ import (
) )
type Network struct { type Network struct {
ID string `gorm:"index"` ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Name string Name string
Description string Description string

View File

@@ -232,7 +232,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
if peer.Name != update.Name { if peer.Name != update.Name {
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID) existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID, update.Name)
if err != nil { if err != nil {
return err return err
} }
@@ -1467,13 +1467,13 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return groupIDs, err return groupIDs, err
} }
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string, peerName string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) dnsLabels, err := transaction.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
existingLabels := make(types.LookupMap) existingLabels := make(types.LookupMap, len(dnsLabels))
for _, label := range dnsLabels { for _, label := range dnsLabels {
existingLabels[label] = struct{}{} existingLabels[label] = struct{}{}
} }

View File

@@ -116,10 +116,11 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock") log.WithContext(ctx).Tracef("acquiring global lock")
start := time.Now() start := time.Now()
s.globalAccountLock.Lock() s.globalAccountLock.Lock()
lockTime := time.Now()
unlock = func() { unlock = func() {
s.globalAccountLock.Unlock() s.globalAccountLock.Unlock()
log.WithContext(ctx).Tracef("released global lock in %v", time.Since(start)) log.WithContext(ctx).Tracef("released global lock: acquired in %v, hold for %v", time.Since(start), time.Since(lockTime))
} }
took := time.Since(start) took := time.Since(start)
@@ -139,10 +140,11 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex) mtx := value.(*sync.RWMutex)
mtx.Lock() mtx.Lock()
lockTime := time.Now()
unlock = func() { unlock = func() {
mtx.Unlock() mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start)) log.WithContext(ctx).Tracef("released write lock for ID %s: acquired in %v, hold for %v", uniqueID, time.Since(start), time.Since(lockTime))
} }
return unlock return unlock
@@ -156,10 +158,11 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex) mtx := value.(*sync.RWMutex)
mtx.RLock() mtx.RLock()
lockTime := time.Now()
unlock = func() { unlock = func() {
mtx.RUnlock() mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start)) log.WithContext(ctx).Tracef("released read lock for ID %s: acquired in %v, hold for %v", uniqueID, time.Since(start), time.Since(lockTime))
} }
return unlock return unlock
@@ -873,7 +876,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
return accountID, nil return accountID, nil
} }
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]struct{}, error) {
var ipJSONStrings []string var ipJSONStrings []string
// Fetch the IP addresses as JSON strings // Fetch the IP addresses as JSON strings
@@ -887,23 +890,22 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
} }
// Convert the JSON strings to net.IP objects ips := make(map[string]struct{}, len(ipJSONStrings))
ips := make([]net.IP, len(ipJSONStrings)) for _, ipJSON := range ipJSONStrings {
for i, ipJSON := range ipJSONStrings {
var ip net.IP var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil { if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store") return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
} }
ips[i] = ip ips[ip.String()] = struct{}{}
} }
return ips, nil return ips, nil
} }
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { func (s *SqlStore) GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountID string, peerName string) ([]string, error) {
var labels []string var labels []string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID). Where("account_id = ? AND dns_label LIKE ?", accountID, peerName+"%").
Pluck("dns_label", &labels) Pluck("dns_label", &labels)
if result.Error != nil { if result.Error != nil {
@@ -1196,12 +1198,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingSt
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
} }
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID) group.Peers = append(group.Peers, peerID)
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {

View File

@@ -992,19 +992,20 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
_, err = store.GetAccount(context.Background(), existingAccountID) _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err) require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
ID: "peer1", ID: "peer1",
AccountID: existingAccountID, AccountID: existingAccountID,
DNSLabel: "peer1.domain.test", DNSLabel: "peer1.domain.test",
} }
labels, err := store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err) require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels) assert.Equal(t, []string{"peer1.domain.test"}, labels)
@@ -1016,7 +1017,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.NoError(t, err) require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer2.DNSLabel)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
} }

View File

@@ -115,7 +115,7 @@ type Store interface {
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountId string, dnsName string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
@@ -150,7 +150,7 @@ type Store interface {
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) (map[string]struct{}, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)

View File

@@ -1,6 +1,8 @@
package types package types
import ( import (
"encoding/binary"
"fmt"
"math/rand" "math/rand"
"net" "net"
"sync" "sync"
@@ -13,7 +15,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -160,25 +161,74 @@ func (n *Network) Copy() *Network {
// AllocatePeerIP pics an available IP from an net.IPNet. // AllocatePeerIP pics an available IP from an net.IPNet.
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps // This method considers already taken IPs and reuses IPs if there are gaps in takenIps
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 // E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { func AllocatePeerIP(ipNet net.IPNet, takenIps map[string]struct{}) (net.IP, error) {
takenIPMap := make(map[string]struct{}) numOfIPsInSubnet := numOfIPs(ipNet)
takenIPMap[ipNet.IP.String()] = struct{}{} if len(takenIps) < numOfIPsInSubnet {
for _, ip := range takenIps { ip, err := allocateRandomFreeIP(ipNet, takenIps, numOfIPsInSubnet)
takenIPMap[ip.String()] = struct{}{} if err == nil {
return ip, nil
}
}
return allocateNextFreeIP(ipNet, takenIps, numOfIPsInSubnet)
}
func allocateNextFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) {
ip := ipNet.IP.Mask(ipNet.Mask)
ip4 := ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("only IPv4 is supported")
}
start := binary.BigEndian.Uint32(ip4)
for i := uint32(1); i < uint32(numIPs-1); i++ {
candidate := make(net.IP, 4)
binary.BigEndian.PutUint32(candidate, start+i)
if _, taken := takenIps[candidate.String()]; !taken {
return candidate, nil
}
} }
ips, _ := generateIPs(&ipNet, takenIPMap) return nil, fmt.Errorf("no available IPs in network %s", ipNet.String())
}
if len(ips) == 0 { func allocateRandomFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) {
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) ip := ipNet.IP.Mask(ipNet.Mask)
ip4 := ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("only IPv4 is supported")
}
start := binary.BigEndian.Uint32(ip4)
r := rand.New(rand.NewSource(time.Now().UnixNano()))
const maxTries = 1000
for i := 0; i < maxTries; i++ {
randomOffset := uint32(r.Intn(numIPs-2)) + 1
candidate := make(net.IP, 4)
binary.BigEndian.PutUint32(candidate, start+randomOffset)
if _, taken := takenIps[candidate.String()]; !taken {
return candidate, nil
}
} }
// pick a random IP for i := uint32(1); i < uint32(numIPs-1); i++ {
s := rand.NewSource(time.Now().Unix()) candidate := make(net.IP, 4)
r := rand.New(s) binary.BigEndian.PutUint32(candidate, start+i)
intn := r.Intn(len(ips)) if _, taken := takenIps[candidate.String()]; !taken {
return candidate, nil
}
}
return ips[intn], nil return nil, fmt.Errorf("failed to randomly generate ip in network %s", ipNet.String())
}
func numOfIPs(ipNet net.IPNet) int {
ones, bits := ipNet.Mask.Size()
numIPs := 1 << (bits - ones)
return numIPs
} }
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list // generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list

View File

@@ -17,23 +17,23 @@ func TestNewNetwork(t *testing.T) {
func TestAllocatePeerIP(t *testing.T) { func TestAllocatePeerIP(t *testing.T) {
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}} ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
var ips []net.IP var ips map[string]struct{}
for i := 0; i < 252; i++ { for i := 0; i < 252; i++ {
ip, err := AllocatePeerIP(ipNet, ips) ip, err := AllocatePeerIP(ipNet, ips)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ips = append(ips, ip) ips[ip.String()] = struct{}{}
} }
assert.Len(t, ips, 252) assert.Len(t, ips, 252)
uniq := make(map[string]struct{}) uniq := make(map[string]struct{})
for _, ip := range ips { for ip := range ips {
if _, ok := uniq[ip.String()]; !ok { if _, ok := uniq[ip]; !ok {
uniq[ip.String()] = struct{}{} uniq[ip] = struct{}{}
} else { } else {
t.Errorf("found duplicate IP %s", ip.String()) t.Errorf("found duplicate IP %s", ip)
} }
} }
} }
@@ -49,3 +49,42 @@ func TestGenerateIPs(t *testing.T) {
t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String()) t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String())
} }
} }
func BenchmarkAllocatePeerIP(b *testing.B) {
testCase := []struct {
name string
numUsedIPs int
}{
{"1000", 1000},
{"10000", 10000},
{"30000", 30000},
{"40000", 40000},
{"60000", 60000},
}
network := NewNetwork()
for _, tc := range testCase {
b.Run(tc.name, func(b *testing.B) {
usedIPs := generateUsedIPs(network.Net, tc.numUsedIPs)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := AllocatePeerIP(network.Net, usedIPs)
if err != nil {
b.Fatal(err)
}
}
})
}
}
func generateUsedIPs(ipNet net.IPNet, numIPs int) map[string]struct{} {
usedIPs := make(map[string]struct{}, numIPs)
for i := 0; i < numIPs; i++ {
ip, err := AllocatePeerIP(ipNet, usedIPs)
if err != nil {
return nil
}
usedIPs[ip.String()] = struct{}{}
}
return usedIPs
}