Compare commits

...

3 Commits

Author SHA1 Message Date
Pascal Fischer
3228a3206c modify store lock logs 2025-04-29 11:48:31 +02:00
Pascal Fischer
463d402000 improve getFreeIP and getFreeDNS [WIP] 2025-04-28 19:40:04 +02:00
Pascal Fischer
d40f60db94 add gorm tag for primary key for the networks objects 2025-04-28 14:24:17 +02:00
10 changed files with 138 additions and 52 deletions

View File

@@ -1595,7 +1595,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
existingLabels, err := s.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerHostName)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}

View File

@@ -30,7 +30,7 @@ func (p NetworkResourceType) String() string {
}
type NetworkResource struct {
ID string `gorm:"index"`
ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"`
AccountID string `gorm:"index"`
Name string

View File

@@ -10,7 +10,7 @@ import (
)
type NetworkRouter struct {
ID string `gorm:"index"`
ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"`
AccountID string `gorm:"index"`
Peer string

View File

@@ -7,7 +7,7 @@ import (
)
type Network struct {
ID string `gorm:"index"`
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Description string

View File

@@ -232,7 +232,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
if peer.Name != update.Name {
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID, update.Name)
if err != nil {
return err
}
@@ -1467,13 +1467,13 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return groupIDs, err
}
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string, peerName string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerName)
if err != nil {
return nil, err
}
existingLabels := make(types.LookupMap)
existingLabels := make(types.LookupMap, len(dnsLabels))
for _, label := range dnsLabels {
existingLabels[label] = struct{}{}
}

View File

@@ -116,10 +116,11 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
lockTime := time.Now()
unlock = func() {
s.globalAccountLock.Unlock()
log.WithContext(ctx).Tracef("released global lock in %v", time.Since(start))
log.WithContext(ctx).Tracef("released global lock: acquired in %v, hold for %v", time.Since(start), time.Since(lockTime))
}
took := time.Since(start)
@@ -139,10 +140,11 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
lockTime := time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
log.WithContext(ctx).Tracef("released write lock for ID %s: acquired in %v, hold for %v", uniqueID, time.Since(start), time.Since(lockTime))
}
return unlock
@@ -156,10 +158,11 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
lockTime := time.Now()
unlock = func() {
mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
log.WithContext(ctx).Tracef("released read lock for ID %s: acquired in %v, hold for %v", uniqueID, time.Since(start), time.Since(lockTime))
}
return unlock
@@ -873,7 +876,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
return accountID, nil
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]struct{}, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
@@ -887,23 +890,22 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
ips := make(map[string]struct{}, len(ipJSONStrings))
for _, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
ips[ip.String()] = struct{}{}
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
func (s *SqlStore) GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountID string, peerName string) ([]string, error) {
var labels []string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Where("account_id = ? AND dns_label LIKE ?", accountID, peerName+"%").
Pluck("dns_label", &labels)
if result.Error != nil {
@@ -1196,12 +1198,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingSt
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {

View File

@@ -992,19 +992,20 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
labels, err := store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
@@ -1016,7 +1017,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer2.DNSLabel)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}

View File

@@ -115,7 +115,7 @@ type Store interface {
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountId string, dnsName string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
@@ -150,7 +150,7 @@ type Store interface {
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) (map[string]struct{}, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)

View File

@@ -1,6 +1,8 @@
package types
import (
"encoding/binary"
"fmt"
"math/rand"
"net"
"sync"
@@ -13,7 +15,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
)
@@ -160,25 +161,74 @@ func (n *Network) Copy() *Network {
// AllocatePeerIP pics an available IP from an net.IPNet.
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
takenIPMap := make(map[string]struct{})
takenIPMap[ipNet.IP.String()] = struct{}{}
for _, ip := range takenIps {
takenIPMap[ip.String()] = struct{}{}
func AllocatePeerIP(ipNet net.IPNet, takenIps map[string]struct{}) (net.IP, error) {
numOfIPsInSubnet := numOfIPs(ipNet)
if len(takenIps) < numOfIPsInSubnet {
ip, err := allocateRandomFreeIP(ipNet, takenIps, numOfIPsInSubnet)
if err == nil {
return ip, nil
}
}
return allocateNextFreeIP(ipNet, takenIps, numOfIPsInSubnet)
}
func allocateNextFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) {
ip := ipNet.IP.Mask(ipNet.Mask)
ip4 := ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("only IPv4 is supported")
}
start := binary.BigEndian.Uint32(ip4)
for i := uint32(1); i < uint32(numIPs-1); i++ {
candidate := make(net.IP, 4)
binary.BigEndian.PutUint32(candidate, start+i)
if _, taken := takenIps[candidate.String()]; !taken {
return candidate, nil
}
}
ips, _ := generateIPs(&ipNet, takenIPMap)
return nil, fmt.Errorf("no available IPs in network %s", ipNet.String())
}
if len(ips) == 0 {
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
func allocateRandomFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) {
ip := ipNet.IP.Mask(ipNet.Mask)
ip4 := ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("only IPv4 is supported")
}
start := binary.BigEndian.Uint32(ip4)
r := rand.New(rand.NewSource(time.Now().UnixNano()))
const maxTries = 1000
for i := 0; i < maxTries; i++ {
randomOffset := uint32(r.Intn(numIPs-2)) + 1
candidate := make(net.IP, 4)
binary.BigEndian.PutUint32(candidate, start+randomOffset)
if _, taken := takenIps[candidate.String()]; !taken {
return candidate, nil
}
}
// pick a random IP
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)
intn := r.Intn(len(ips))
for i := uint32(1); i < uint32(numIPs-1); i++ {
candidate := make(net.IP, 4)
binary.BigEndian.PutUint32(candidate, start+i)
if _, taken := takenIps[candidate.String()]; !taken {
return candidate, nil
}
}
return ips[intn], nil
return nil, fmt.Errorf("failed to randomly generate ip in network %s", ipNet.String())
}
func numOfIPs(ipNet net.IPNet) int {
ones, bits := ipNet.Mask.Size()
numIPs := 1 << (bits - ones)
return numIPs
}
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list

View File

@@ -17,23 +17,23 @@ func TestNewNetwork(t *testing.T) {
func TestAllocatePeerIP(t *testing.T) {
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
var ips []net.IP
var ips map[string]struct{}
for i := 0; i < 252; i++ {
ip, err := AllocatePeerIP(ipNet, ips)
if err != nil {
t.Fatal(err)
}
ips = append(ips, ip)
ips[ip.String()] = struct{}{}
}
assert.Len(t, ips, 252)
uniq := make(map[string]struct{})
for _, ip := range ips {
if _, ok := uniq[ip.String()]; !ok {
uniq[ip.String()] = struct{}{}
for ip := range ips {
if _, ok := uniq[ip]; !ok {
uniq[ip] = struct{}{}
} else {
t.Errorf("found duplicate IP %s", ip.String())
t.Errorf("found duplicate IP %s", ip)
}
}
}
@@ -49,3 +49,42 @@ func TestGenerateIPs(t *testing.T) {
t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String())
}
}
func BenchmarkAllocatePeerIP(b *testing.B) {
testCase := []struct {
name string
numUsedIPs int
}{
{"1000", 1000},
{"10000", 10000},
{"30000", 30000},
{"40000", 40000},
{"60000", 60000},
}
network := NewNetwork()
for _, tc := range testCase {
b.Run(tc.name, func(b *testing.B) {
usedIPs := generateUsedIPs(network.Net, tc.numUsedIPs)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := AllocatePeerIP(network.Net, usedIPs)
if err != nil {
b.Fatal(err)
}
}
})
}
}
func generateUsedIPs(ipNet net.IPNet, numIPs int) map[string]struct{} {
usedIPs := make(map[string]struct{}, numIPs)
for i := 0; i < numIPs; i++ {
ip, err := AllocatePeerIP(ipNet, usedIPs)
if err != nil {
return nil
}
usedIPs[ip.String()] = struct{}{}
}
return usedIPs
}