diff --git a/management/server/account.go b/management/server/account.go index ab1ffe8b3..dfcbdbdd1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1595,7 +1595,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction } func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) + existingLabels, err := s.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerHostName) if err != nil { return "", fmt.Errorf("failed to get peer dns labels: %w", err) } diff --git a/management/server/peer.go b/management/server/peer.go index 908610fbe..053381041 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -232,7 +232,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peer.Name != update.Name { - existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID) + existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID, update.Name) if err != nil { return err } @@ -1467,13 +1467,13 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return groupIDs, err } -func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { - dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) +func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string, peerName string) (types.LookupMap, error) { + dnsLabels, err := transaction.GetPeerLabelsInAccountForName(ctx, store.LockingStrengthShare, accountID, peerName) if err != nil { return nil, err } - existingLabels := make(types.LookupMap) + existingLabels := make(types.LookupMap, len(dnsLabels)) for _, label := range dnsLabels { existingLabels[label] = struct{}{} } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 7d3b288e0..6867d7fac 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -873,7 +873,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) return accountID, nil } -func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { +func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]struct{}, error) { var ipJSONStrings []string // Fetch the IP addresses as JSON strings @@ -887,23 +887,22 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } - // Convert the JSON strings to net.IP objects - ips := make([]net.IP, len(ipJSONStrings)) - for i, ipJSON := range ipJSONStrings { + ips := make(map[string]struct{}, len(ipJSONStrings)) + for _, ipJSON := range ipJSONStrings { var ip net.IP if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil { return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store") } - ips[i] = ip + ips[ip.String()] = struct{}{} } return ips, nil } -func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { +func (s *SqlStore) GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountID string, peerName string) ([]string, error) { var labels []string result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). - Where("account_id = ?", accountID). + Where("account_id = ? AND dns_label LIKE ?", accountID, peerName+"%"). Pluck("dns_label", &labels) if result.Error != nil { @@ -1196,12 +1195,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingSt return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } - for _, existingPeerID := range group.Peers { - if existingPeerID == peerID { - return nil - } - } - group.Peers = append(group.Peers, peerID) if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 8bd8ce098..f5af21273 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -992,19 +992,20 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) - require.NoError(t, err) - assert.Equal(t, []string{}, labels) - peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, DNSLabel: "peer1.domain.test", } + + labels, err := store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel) + require.NoError(t, err) + assert.Equal(t, []string{}, labels) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer1.DNSLabel) require.NoError(t, err) assert.Equal(t, []string{"peer1.domain.test"}, labels) @@ -1016,7 +1017,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + labels, err = store.GetPeerLabelsInAccountForName(context.Background(), LockingStrengthShare, existingAccountID, peer2.DNSLabel) require.NoError(t, err) assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) } diff --git a/management/server/store/store.go b/management/server/store/store.go index ca332a493..223b3258d 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -115,7 +115,7 @@ type Store interface { SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error - GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + GetPeerLabelsInAccountForName(ctx context.Context, lockStrength LockingStrength, accountId string, dnsName string) ([]string, error) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) @@ -150,7 +150,7 @@ type Store interface { SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error - GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) (map[string]struct{}, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) diff --git a/management/server/types/network.go b/management/server/types/network.go index 00082bb41..c62b4488e 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -1,6 +1,8 @@ package types import ( + "encoding/binary" + "fmt" "math/rand" "net" "sync" @@ -13,7 +15,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -160,25 +161,74 @@ func (n *Network) Copy() *Network { // AllocatePeerIP pics an available IP from an net.IPNet. // This method considers already taken IPs and reuses IPs if there are gaps in takenIps // E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 -func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { - takenIPMap := make(map[string]struct{}) - takenIPMap[ipNet.IP.String()] = struct{}{} - for _, ip := range takenIps { - takenIPMap[ip.String()] = struct{}{} +func AllocatePeerIP(ipNet net.IPNet, takenIps map[string]struct{}) (net.IP, error) { + numOfIPsInSubnet := numOfIPs(ipNet) + if len(takenIps) < numOfIPsInSubnet { + ip, err := allocateRandomFreeIP(ipNet, takenIps, numOfIPsInSubnet) + if err == nil { + return ip, nil + } + } + return allocateNextFreeIP(ipNet, takenIps, numOfIPsInSubnet) +} + +func allocateNextFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) { + ip := ipNet.IP.Mask(ipNet.Mask) + + ip4 := ip.To4() + if ip4 == nil { + return nil, fmt.Errorf("only IPv4 is supported") + } + start := binary.BigEndian.Uint32(ip4) + + for i := uint32(1); i < uint32(numIPs-1); i++ { + candidate := make(net.IP, 4) + binary.BigEndian.PutUint32(candidate, start+i) + + if _, taken := takenIps[candidate.String()]; !taken { + return candidate, nil + } } - ips, _ := generateIPs(&ipNet, takenIPMap) + return nil, fmt.Errorf("no available IPs in network %s", ipNet.String()) +} - if len(ips) == 0 { - return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) +func allocateRandomFreeIP(ipNet net.IPNet, takenIps map[string]struct{}, numIPs int) (net.IP, error) { + ip := ipNet.IP.Mask(ipNet.Mask) + ip4 := ip.To4() + if ip4 == nil { + return nil, fmt.Errorf("only IPv4 is supported") + } + start := binary.BigEndian.Uint32(ip4) + + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + const maxTries = 1000 + for i := 0; i < maxTries; i++ { + randomOffset := uint32(r.Intn(numIPs-2)) + 1 + candidate := make(net.IP, 4) + binary.BigEndian.PutUint32(candidate, start+randomOffset) + + if _, taken := takenIps[candidate.String()]; !taken { + return candidate, nil + } } - // pick a random IP - s := rand.NewSource(time.Now().Unix()) - r := rand.New(s) - intn := r.Intn(len(ips)) + for i := uint32(1); i < uint32(numIPs-1); i++ { + candidate := make(net.IP, 4) + binary.BigEndian.PutUint32(candidate, start+i) + if _, taken := takenIps[candidate.String()]; !taken { + return candidate, nil + } + } - return ips[intn], nil + return nil, fmt.Errorf("failed to randomly generate ip in network %s", ipNet.String()) +} + +func numOfIPs(ipNet net.IPNet) int { + ones, bits := ipNet.Mask.Size() + numIPs := 1 << (bits - ones) + return numIPs } // generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list diff --git a/management/server/types/network_test.go b/management/server/types/network_test.go index d0b0894d4..dfddd262b 100644 --- a/management/server/types/network_test.go +++ b/management/server/types/network_test.go @@ -17,23 +17,23 @@ func TestNewNetwork(t *testing.T) { func TestAllocatePeerIP(t *testing.T) { ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}} - var ips []net.IP + var ips map[string]struct{} for i := 0; i < 252; i++ { ip, err := AllocatePeerIP(ipNet, ips) if err != nil { t.Fatal(err) } - ips = append(ips, ip) + ips[ip.String()] = struct{}{} } assert.Len(t, ips, 252) uniq := make(map[string]struct{}) - for _, ip := range ips { - if _, ok := uniq[ip.String()]; !ok { - uniq[ip.String()] = struct{}{} + for ip := range ips { + if _, ok := uniq[ip]; !ok { + uniq[ip] = struct{}{} } else { - t.Errorf("found duplicate IP %s", ip.String()) + t.Errorf("found duplicate IP %s", ip) } } } @@ -49,3 +49,42 @@ func TestGenerateIPs(t *testing.T) { t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String()) } } + +func BenchmarkAllocatePeerIP(b *testing.B) { + testCase := []struct { + name string + numUsedIPs int + }{ + {"1000", 1000}, + {"10000", 10000}, + {"30000", 30000}, + {"40000", 40000}, + {"60000", 60000}, + } + network := NewNetwork() + + for _, tc := range testCase { + b.Run(tc.name, func(b *testing.B) { + usedIPs := generateUsedIPs(network.Net, tc.numUsedIPs) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := AllocatePeerIP(network.Net, usedIPs) + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func generateUsedIPs(ipNet net.IPNet, numIPs int) map[string]struct{} { + usedIPs := make(map[string]struct{}, numIPs) + for i := 0; i < numIPs; i++ { + ip, err := AllocatePeerIP(ipNet, usedIPs) + if err != nil { + return nil + } + usedIPs[ip.String()] = struct{}{} + } + return usedIPs +}