mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] fix domain uniqueness (#5529)
This commit is contained in:
@@ -36,11 +36,11 @@ func TestReapExpiredExposes(t *testing.T) {
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
// Expired service should be deleted
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "expired service should be deleted")
|
||||
|
||||
// Non-expired service should remain
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain)
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
|
||||
require.NoError(t, err, "active service should remain")
|
||||
}
|
||||
|
||||
@@ -191,14 +191,14 @@ func TestReapSkipsRenewedService(t *testing.T) {
|
||||
// Reaper should skip it because the re-check sees a fresh timestamp
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, err, "renewed service should survive reaping")
|
||||
}
|
||||
|
||||
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
||||
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
||||
t.Helper()
|
||||
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain)
|
||||
svc, err := s.GetServiceByDomain(context.Background(), domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
expired := time.Now().Add(-2 * exposeTTL)
|
||||
|
||||
@@ -199,7 +199,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -245,7 +245,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
}
|
||||
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -261,8 +261,8 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
@@ -271,7 +271,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
return status.Errorf(status.AlreadyExists, "domain already taken")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -352,7 +352,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
}
|
||||
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -805,7 +805,7 @@ func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID,
|
||||
|
||||
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
|
||||
svc, err := m.store.GetServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -72,7 +72,6 @@ func TestInitializeServiceForCreate(t *testing.T) {
|
||||
|
||||
func TestCheckDomainAvailable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -88,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "available.com").
|
||||
GetServiceByDomain(ctx, "available.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
},
|
||||
expectedError: false,
|
||||
@@ -99,7 +98,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
GetServiceByDomain(ctx, "exists.com").
|
||||
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
@@ -111,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "service-123",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
GetServiceByDomain(ctx, "exists.com").
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
@@ -122,7 +121,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "service-456",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
GetServiceByDomain(ctx, "exists.com").
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
@@ -134,7 +133,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "error.com").
|
||||
GetServiceByDomain(ctx, "error.com").
|
||||
Return(nil, errors.New("database error"))
|
||||
},
|
||||
expectedError: true,
|
||||
@@ -150,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
@@ -168,7 +167,6 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
|
||||
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("empty domain", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
@@ -176,11 +174,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
GetServiceByDomain(ctx, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
@@ -191,11 +189,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
GetServiceByDomain(ctx, "test.com").
|
||||
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
@@ -209,11 +207,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
GetServiceByDomain(ctx, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
@@ -241,7 +239,7 @@ func TestPersistNewService(t *testing.T) {
|
||||
// Create another mock for the transaction
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "new.com").
|
||||
GetServiceByDomain(ctx, "new.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
txMock.EXPECT().
|
||||
CreateService(ctx, service).
|
||||
@@ -272,7 +270,7 @@ func TestPersistNewService(t *testing.T) {
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
GetServiceByDomain(ctx, "existing.com").
|
||||
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
@@ -814,7 +812,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
|
||||
|
||||
// Verify service is persisted in store
|
||||
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.Domain, persisted.Domain)
|
||||
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
|
||||
@@ -977,7 +975,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify service is deleted
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "service should be deleted")
|
||||
})
|
||||
|
||||
@@ -1012,7 +1010,7 @@ func TestStopServiceFromPeer(t *testing.T) {
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "service should be deleted")
|
||||
})
|
||||
}
|
||||
@@ -1031,7 +1029,7 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
|
||||
|
||||
svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
|
||||
|
||||
@@ -134,7 +134,7 @@ type Service struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string `gorm:"index"`
|
||||
Domain string `gorm:"type:varchar(255);uniqueIndex"`
|
||||
ProxyCluster string `gorm:"index"`
|
||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
||||
Enabled bool
|
||||
@@ -535,15 +535,15 @@ var hopByHopHeaders = map[string]struct{}{
|
||||
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
|
||||
// and cannot be overridden.
|
||||
var reservedHeaders = map[string]struct{}{
|
||||
"Content-Length": {},
|
||||
"Content-Type": {},
|
||||
"Cookie": {},
|
||||
"Forwarded": {},
|
||||
"X-Forwarded-For": {},
|
||||
"X-Forwarded-Host": {},
|
||||
"X-Forwarded-Port": {},
|
||||
"X-Forwarded-Proto": {},
|
||||
"X-Real-Ip": {},
|
||||
"Content-Length": {},
|
||||
"Content-Type": {},
|
||||
"Cookie": {},
|
||||
"Forwarded": {},
|
||||
"X-Forwarded-For": {},
|
||||
"X-Forwarded-Host": {},
|
||||
"X-Forwarded-Port": {},
|
||||
"X-Forwarded-Proto": {},
|
||||
"X-Real-Ip": {},
|
||||
}
|
||||
|
||||
func validateTargetOptions(idx int, opts *TargetOptions) error {
|
||||
|
||||
@@ -4977,9 +4977,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
var service *rpservice.Service
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||
result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
|
||||
|
||||
@@ -257,7 +257,7 @@ type Store interface {
|
||||
UpdateService(ctx context.Context, service *rpservice.Service) error
|
||||
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
||||
|
||||
|
||||
@@ -1932,18 +1932,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
|
||||
Reference in New Issue
Block a user