diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go index bd9f4b93b..c831b4a22 100644 --- a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -36,11 +36,11 @@ func TestReapExpiredExposes(t *testing.T) { mgr.exposeReaper.reapExpiredExposes(ctx) // Expired service should be deleted - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.Error(t, err, "expired service should be deleted") // Non-expired service should remain - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp2.Domain) require.NoError(t, err, "active service should remain") } @@ -191,14 +191,14 @@ func TestReapSkipsRenewedService(t *testing.T) { // Reaper should skip it because the re-check sees a fresh timestamp mgr.exposeReaper.reapExpiredExposes(ctx) - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.NoError(t, err, "renewed service should survive reaping") } // expireEphemeralService backdates meta_last_renewed_at to force expiration. func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { t.Helper() - svc, err := s.GetServiceByDomain(context.Background(), accountID, domain) + svc, err := s.GetServiceByDomain(context.Background(), domain) require.NoError(t, err) expired := time.Now().Add(-2 * exposeTTL) diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index b5e643799..56a1fc98a 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -199,7 +199,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { + if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil { return err } @@ -245,7 +245,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) } - if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { return err } @@ -261,8 +261,8 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee }) } -func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { - existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) +func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error { + existingService, err := transaction.GetServiceByDomain(ctx, domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { return fmt.Errorf("failed to check existing service: %w", err) @@ -271,7 +271,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St } if existingService != nil && existingService.ID != excludeServiceID { - return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain) + return status.Errorf(status.AlreadyExists, "domain already taken") } return nil @@ -352,7 +352,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se } func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { - if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { + if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil { return err } @@ -805,7 +805,7 @@ func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, // lookupPeerService finds a peer-initiated service by domain and validates ownership. func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) { - svc, err := m.store.GetServiceByDomain(ctx, accountID, domain) + svc, err := m.store.GetServiceByDomain(ctx, domain) if err != nil { return nil, err } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 196eead22..0cb8fa02a 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -72,7 +72,6 @@ func TestInitializeServiceForCreate(t *testing.T) { func TestCheckDomainAvailable(t *testing.T) { ctx := context.Background() - accountID := "test-account" tests := []struct { name string @@ -88,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "available.com"). + GetServiceByDomain(ctx, "available.com"). Return(nil, status.Errorf(status.NotFound, "not found")) }, expectedError: false, @@ -99,7 +98,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). + GetServiceByDomain(ctx, "exists.com"). Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil) }, expectedError: true, @@ -111,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "service-123", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). + GetServiceByDomain(ctx, "exists.com"). Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) }, expectedError: false, @@ -122,7 +121,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "service-456", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). + GetServiceByDomain(ctx, "exists.com"). Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) }, expectedError: true, @@ -134,7 +133,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "error.com"). + GetServiceByDomain(ctx, "error.com"). Return(nil, errors.New("database error")) }, expectedError: true, @@ -150,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) { tt.setupMock(mockStore) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) + err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID) if tt.expectedError { require.Error(t, err) @@ -168,7 +167,6 @@ func TestCheckDomainAvailable(t *testing.T) { func TestCheckDomainAvailable_EdgeCases(t *testing.T) { ctx := context.Background() - accountID := "test-account" t.Run("empty domain", func(t *testing.T) { ctrl := gomock.NewController(t) @@ -176,11 +174,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, ""). + GetServiceByDomain(ctx, ""). Return(nil, status.Errorf(status.NotFound, "not found")) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") + err := mgr.checkDomainAvailable(ctx, mockStore, "", "") assert.NoError(t, err) }) @@ -191,11 +189,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, "test.com"). + GetServiceByDomain(ctx, "test.com"). Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") + err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "") assert.Error(t, err) sErr, ok := status.FromError(err) @@ -209,11 +207,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, "nil.com"). + GetServiceByDomain(ctx, "nil.com"). Return(nil, nil) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") + err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "") assert.NoError(t, err) }) @@ -241,7 +239,7 @@ func TestPersistNewService(t *testing.T) { // Create another mock for the transaction txMock := store.NewMockStore(ctrl) txMock.EXPECT(). - GetServiceByDomain(ctx, accountID, "new.com"). + GetServiceByDomain(ctx, "new.com"). Return(nil, status.Errorf(status.NotFound, "not found")) txMock.EXPECT(). CreateService(ctx, service). @@ -272,7 +270,7 @@ func TestPersistNewService(t *testing.T) { DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { txMock := store.NewMockStore(ctrl) txMock.EXPECT(). - GetServiceByDomain(ctx, accountID, "existing.com"). + GetServiceByDomain(ctx, "existing.com"). Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil) return fn(txMock) @@ -814,7 +812,7 @@ func TestCreateServiceFromPeer(t *testing.T) { assert.NotEmpty(t, resp.ServiceURL, "service URL should be set") // Verify service is persisted in store - persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain) require.NoError(t, err) assert.Equal(t, resp.Domain, persisted.Domain) assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral") @@ -977,7 +975,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { require.NoError(t, err) // Verify service is deleted - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.Error(t, err, "service should be deleted") }) @@ -1012,7 +1010,7 @@ func TestStopServiceFromPeer(t *testing.T) { err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) require.NoError(t, err) - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.Error(t, err, "service should be deleted") }) } @@ -1031,7 +1029,7 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { require.NoError(t, err) assert.Equal(t, int64(1), count, "one ephemeral service should exist after create") - svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + svc, err := testStore.GetServiceByDomain(ctx, resp.Domain) require.NoError(t, err) err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID) diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index cd9311b44..bfad7fe9a 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -134,7 +134,7 @@ type Service struct { ID string `gorm:"primaryKey"` AccountID string `gorm:"index"` Name string - Domain string `gorm:"index"` + Domain string `gorm:"type:varchar(255);uniqueIndex"` ProxyCluster string `gorm:"index"` Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"` Enabled bool @@ -535,15 +535,15 @@ var hopByHopHeaders = map[string]struct{}{ // reservedHeaders are set authoritatively by the proxy or control HTTP framing // and cannot be overridden. var reservedHeaders = map[string]struct{}{ - "Content-Length": {}, - "Content-Type": {}, - "Cookie": {}, - "Forwarded": {}, - "X-Forwarded-For": {}, - "X-Forwarded-Host": {}, - "X-Forwarded-Port": {}, - "X-Forwarded-Proto": {}, - "X-Real-Ip": {}, + "Content-Length": {}, + "Content-Type": {}, + "Cookie": {}, + "Forwarded": {}, + "X-Forwarded-For": {}, + "X-Forwarded-Host": {}, + "X-Forwarded-Port": {}, + "X-Forwarded-Proto": {}, + "X-Real-Ip": {}, } func validateTargetOptions(idx int, opts *TargetOptions) error { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8f147d915..5997c10e2 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4977,9 +4977,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren return service, nil } -func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) { +func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { var service *rpservice.Service - result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service) + result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain) diff --git a/management/server/store/store.go b/management/server/store/store.go index 5123cde72..1fa99fd05 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -257,7 +257,7 @@ type Store interface { UpdateService(ctx context.Context, service *rpservice.Service) error DeleteService(ctx context.Context, accountID, serviceID string) error GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) - GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) + GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 414872fbb..130df4485 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1932,18 +1932,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout } // GetServiceByDomain mocks base method. -func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) { +func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain) + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) ret0, _ := ret[0].(*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } // GetServiceByDomain indicates an expected call of GetServiceByDomain. -func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain) } // GetServiceByID mocks base method.