mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Compare commits
3 Commits
feature/re
...
chore/prim
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3228a3206c | ||
|
|
463d402000 | ||
|
|
d40f60db94 |
@@ -1595,7 +1595,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
|
||||
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
||||
existingLabels, err := s.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerHostName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func (p NetworkResourceType) String() string {
|
||||
}
|
||||
|
||||
type NetworkResource struct {
|
||||
ID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
type NetworkRouter struct {
|
||||
ID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
Peer string
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
ID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Description string
|
||||
|
||||
@@ -232,7 +232,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
if peer.Name != update.Name {
|
||||
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
|
||||
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID, update.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1467,13 +1467,13 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
|
||||
return groupIDs, err
|
||||
}
|
||||
|
||||
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
|
||||
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
||||
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string, peerName string) (types.LookupMap, error) {
|
||||
dnsLabels, err := transaction.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingLabels := make(types.LookupMap)
|
||||
existingLabels := make(types.LookupMap, len(dnsLabels))
|
||||
for _, label := range dnsLabels {
|
||||
existingLabels[label] = struct{}{}
|
||||
}
|
||||
|
||||
@@ -116,10 +116,11 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring global lock")
|
||||
start := time.Now()
|
||||
s.globalAccountLock.Lock()
|
||||
lockTime := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
s.globalAccountLock.Unlock()
|
||||
log.WithContext(ctx).Tracef("released global lock in %v", time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released global lock: acquired in %v, hold for %v", time.Since(start), time.Since(lockTime))
|
||||
}
|
||||
|
||||
took := time.Since(start)
|
||||
@@ -139,10 +140,11 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
lockTime := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s: acquired in %v, hold for %v", uniqueID, time.Since(start), time.Since(lockTime))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@@ -156,10 +158,11 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
lockTime := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s: acquired in %v, hold for %v", uniqueID, time.Since(start), time.Since(lockTime))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@@ -873,7 +876,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]struct{}, error) {
|
||||
var ipJSONStrings []string
|
||||
|
||||
// Fetch the IP addresses as JSON strings
|
||||
@@ -887,23 +890,22 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
|
||||
}
|
||||
|
||||
// Convert the JSON strings to net.IP objects
|
||||
ips := make([]net.IP, len(ipJSONStrings))
|
||||
for i, ipJSON := range ipJSONStrings {
|
||||
ips := make(map[string]struct{}, len(ipJSONStrings))
|
||||
for _, ipJSON := range ipJSONStrings {
|
||||
var ip net.IP
|
||||
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
|
||||
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
|
||||
}
|
||||
ips[i] = ip
|
||||
ips[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
func (s *SqlStore) GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountID string, peerName string) ([]string, error) {
|
||||
var labels []string
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("account_id = ? AND dns_label LIKE ?", accountID, peerName+"%").
|
||||
Pluck("dns_label", &labels)
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -1196,12 +1198,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingSt
|
||||
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
|
||||
@@ -992,19 +992,20 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, labels)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
|
||||
labels, err := store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, labels)
|
||||
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
||||
|
||||
@@ -1016,7 +1017,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer2.DNSLabel)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ type Store interface {
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountId string, dnsName string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
||||
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
||||
@@ -150,7 +150,7 @@ type Store interface {
|
||||
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
||||
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) (map[string]struct{}, error)
|
||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -13,7 +15,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -160,25 +161,74 @@ func (n *Network) Copy() *Network {
|
||||
// AllocatePeerIP pics an available IP from an net.IPNet.
|
||||
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
|
||||
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
||||
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
||||
takenIPMap := make(map[string]struct{})
|
||||
takenIPMap[ipNet.IP.String()] = struct{}{}
|
||||
for _, ip := range takenIps {
|
||||
takenIPMap[ip.String()] = struct{}{}
|
||||
func AllocatePeerIP(ipNet net.IPNet, takenIps map[string]struct{}) (net.IP, error) {
|
||||
numOfIPsInSubnet := numOfIPs(ipNet)
|
||||
if len(takenIps) < numOfIPsInSubnet {
|
||||
ip, err := allocateRandomFreeIP(ipNet, takenIps, numOfIPsInSubnet)
|
||||
if err == nil {
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
return allocateNextFreeIP(ipNet, takenIps, numOfIPsInSubnet)
|
||||
}
|
||||
|
||||
func allocateNextFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) {
|
||||
ip := ipNet.IP.Mask(ipNet.Mask)
|
||||
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil, fmt.Errorf("only IPv4 is supported")
|
||||
}
|
||||
start := binary.BigEndian.Uint32(ip4)
|
||||
|
||||
for i := uint32(1); i < uint32(numIPs-1); i++ {
|
||||
candidate := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(candidate, start+i)
|
||||
|
||||
if _, taken := takenIps[candidate.String()]; !taken {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
|
||||
ips, _ := generateIPs(&ipNet, takenIPMap)
|
||||
return nil, fmt.Errorf("no available IPs in network %s", ipNet.String())
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
||||
func allocateRandomFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) {
|
||||
ip := ipNet.IP.Mask(ipNet.Mask)
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil, fmt.Errorf("only IPv4 is supported")
|
||||
}
|
||||
start := binary.BigEndian.Uint32(ip4)
|
||||
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
const maxTries = 1000
|
||||
for i := 0; i < maxTries; i++ {
|
||||
randomOffset := uint32(r.Intn(numIPs-2)) + 1
|
||||
candidate := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(candidate, start+randomOffset)
|
||||
|
||||
if _, taken := takenIps[candidate.String()]; !taken {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
|
||||
// pick a random IP
|
||||
s := rand.NewSource(time.Now().Unix())
|
||||
r := rand.New(s)
|
||||
intn := r.Intn(len(ips))
|
||||
for i := uint32(1); i < uint32(numIPs-1); i++ {
|
||||
candidate := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(candidate, start+i)
|
||||
if _, taken := takenIps[candidate.String()]; !taken {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
|
||||
return ips[intn], nil
|
||||
return nil, fmt.Errorf("failed to randomly generate ip in network %s", ipNet.String())
|
||||
}
|
||||
|
||||
func numOfIPs(ipNet net.IPNet) int {
|
||||
ones, bits := ipNet.Mask.Size()
|
||||
numIPs := 1 << (bits - ones)
|
||||
return numIPs
|
||||
}
|
||||
|
||||
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
|
||||
|
||||
@@ -17,23 +17,23 @@ func TestNewNetwork(t *testing.T) {
|
||||
|
||||
func TestAllocatePeerIP(t *testing.T) {
|
||||
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
|
||||
var ips []net.IP
|
||||
var ips map[string]struct{}
|
||||
for i := 0; i < 252; i++ {
|
||||
ip, err := AllocatePeerIP(ipNet, ips)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
ips[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
assert.Len(t, ips, 252)
|
||||
|
||||
uniq := make(map[string]struct{})
|
||||
for _, ip := range ips {
|
||||
if _, ok := uniq[ip.String()]; !ok {
|
||||
uniq[ip.String()] = struct{}{}
|
||||
for ip := range ips {
|
||||
if _, ok := uniq[ip]; !ok {
|
||||
uniq[ip] = struct{}{}
|
||||
} else {
|
||||
t.Errorf("found duplicate IP %s", ip.String())
|
||||
t.Errorf("found duplicate IP %s", ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -49,3 +49,42 @@ func TestGenerateIPs(t *testing.T) {
|
||||
t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAllocatePeerIP(b *testing.B) {
|
||||
testCase := []struct {
|
||||
name string
|
||||
numUsedIPs int
|
||||
}{
|
||||
{"1000", 1000},
|
||||
{"10000", 10000},
|
||||
{"30000", 30000},
|
||||
{"40000", 40000},
|
||||
{"60000", 60000},
|
||||
}
|
||||
network := NewNetwork()
|
||||
|
||||
for _, tc := range testCase {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
usedIPs := generateUsedIPs(network.Net, tc.numUsedIPs)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := AllocatePeerIP(network.Net, usedIPs)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateUsedIPs(ipNet net.IPNet, numIPs int) map[string]struct{} {
|
||||
usedIPs := make(map[string]struct{}, numIPs)
|
||||
for i := 0; i < numIPs; i++ {
|
||||
ip, err := AllocatePeerIP(ipNet, usedIPs)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
usedIPs[ip.String()] = struct{}{}
|
||||
}
|
||||
return usedIPs
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user