diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4d455f23d..f7dc20438 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2177,3 +2177,17 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, return nil } + +func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) { + jsonValue := fmt.Sprintf(`"%s"`, ip.String()) + + var peer nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, "account_id = ? AND ip = ?", accountID, jsonValue) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peer from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return &peer, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index d84d699bb..44dfe6744 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -23,10 +23,9 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" - - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/migration" @@ -185,6 +184,7 @@ type Store interface { GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error + GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) } type Engine string @@ -353,12 +353,11 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( return nil, nil, fmt.Errorf("failed to create test store: %v", err) } - err = addAllGroupToAccount(ctx, store) + err = addAllGroupToAccount(ctx, store) if err != nil { return nil, nil, fmt.Errorf("failed to add all group to account: %v", err) } - maxRetries := 2 for i := 0; i < maxRetries; i++ { sqlStore, cleanUp, err := getSqlStoreEngine(ctx, store, kind)